fischer-agentkit/tests/unit/test_tool_composition.py

304 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""U3 测试: Tool 组合 (SequentialChain, ParallelFanOut, DynamicSelector) + MCPTool"""
import asyncio
import json
import pytest
from agentkit.tools.base import Tool
from agentkit.tools.composition import DynamicSelector, ParallelFanOut, SequentialChain
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.mcp_tool import MCPTool
# ── Helper: build a simple FunctionTool ──────────────────
def _make_tool(name: str, func=None, **kw):
    """Create a ``FunctionTool`` called *name* for use in tests.

    If *func* is omitted, the tool wraps an async echo function returning
    ``{"tool": name, "input": <kwargs it received>}``, which lets a test see
    which tool actually ran and with what arguments.  Extra keyword
    arguments in *kw* are forwarded unchanged to ``FunctionTool``.
    """

    async def default_func(**kwargs):
        return {"tool": name, "input": kwargs}

    handler = default_func if func is None else func
    return FunctionTool(
        name=name,
        description=f"Test tool {name}",
        func=handler,
        **kw,
    )
# ── SequentialChain tests ────────────────────────────────
class TestSequentialChain:
    """``SequentialChain``: tools run in order, each step fed by the previous one."""

    # NOTE(review): these are bare ``async def`` tests with no marker;
    # presumably pytest-asyncio runs in auto mode for this project — confirm,
    # otherwise pytest skips them.

    async def test_chain_two_tools(self):
        """Two tools in sequence; step1's output dict supplies step2's inputs."""

        async def step1(text: str) -> dict:
            return {"processed": text.upper(), "length": len(text)}

        async def step2(processed: str, length: int) -> dict:
            return {"result": f"{processed}({length})"}

        steps = [
            FunctionTool(func=step1, name="step1", description="Step 1"),
            FunctionTool(func=step2, name="step2", description="Step 2"),
        ]
        pipeline = SequentialChain(
            tools=steps,
            name="upper_and_format",
            description="Uppercase then format",
        )

        out = await pipeline.execute(text="hello")
        # "hello" -> {"processed": "HELLO", "length": 5} -> "HELLO(5)"
        assert out["result"] == "HELLO(5)"

    async def test_chain_single_tool(self):
        """A chain holding one tool returns that tool's own result."""
        only = _make_tool("solo")
        pipeline = SequentialChain(tools=[only], name="solo_chain", description="Solo")

        out = await pipeline.execute(x=1)
        assert out["tool"] == "solo"

    async def test_chain_preserves_dict(self):
        """Each step returns a dict; the last step builds on the first's fields."""

        async def add_field(x: int) -> dict:
            return {"x": x, "doubled": x * 2}

        async def add_label(x: int, doubled: int) -> dict:
            return {"x": x, "doubled": doubled, "label": f"{x}->{doubled}"}

        steps = [
            FunctionTool(func=add_field, name="add_field", description="Add"),
            FunctionTool(func=add_label, name="add_label", description="Label"),
        ]
        pipeline = SequentialChain(tools=steps, name="enrich", description="Enrich data")

        out = await pipeline.execute(x=3)
        # 3 -> {"x": 3, "doubled": 6} -> label "3->6"
        assert out["label"] == "3->6"
# ── ParallelFanOut tests ─────────────────────────────────
class TestParallelFanOut:
    """``ParallelFanOut``: the same input is sent to several tools at once."""

    async def test_parallel_execution(self):
        """Three tools run on one query; every tool's result key shows up."""

        async def search_web(query: str) -> dict:
            return {"web_results": [f"web: {query}"]}

        async def search_db(query: str) -> dict:
            return {"db_results": [f"db: {query}"]}

        async def search_cache(query: str) -> dict:
            return {"cache_results": [f"cache: {query}"]}

        branches = [
            FunctionTool(func=search_web, name="search_web", description="Web"),
            FunctionTool(func=search_db, name="search_db", description="DB"),
            FunctionTool(func=search_cache, name="search_cache", description="Cache"),
        ]
        fanout = ParallelFanOut(
            tools=branches,
            name="multi_search",
            description="Search multiple sources",
        )

        out = await fanout.execute(query="AI")
        for key in ("web_results", "db_results", "cache_results"):
            assert key in out

    async def test_parallel_with_error(self):
        """A failing tool does not block the others; it is listed in ``_errors``."""

        async def good_tool(x: int) -> dict:
            return {"good": x * 2}

        async def bad_tool(x: int) -> dict:
            raise ValueError("boom")

        branches = [
            FunctionTool(func=good_tool, name="good", description="Good"),
            FunctionTool(func=bad_tool, name="bad", description="Bad"),
        ]
        fanout = ParallelFanOut(tools=branches, name="mixed", description="Mixed results")

        out = await fanout.execute(x=5)
        assert out["good"] == 10
        assert "_errors" in out
        assert any(err["tool"] == "bad" for err in out["_errors"])

    async def test_parallel_namespace_merge(self):
        """``merge_strategy="namespace"`` keys each tool's result by tool name."""

        # Both tools return the same key, so without namespacing they would collide.
        async def tool_a(x: int) -> dict:
            return {"result": "a"}

        async def tool_b(x: int) -> dict:
            return {"result": "b"}

        branches = [
            FunctionTool(func=tool_a, name="tool_a", description="A"),
            FunctionTool(func=tool_b, name="tool_b", description="B"),
        ]
        fanout = ParallelFanOut(
            tools=branches,
            name="namespaced",
            description="Namespaced",
            merge_strategy="namespace",
        )

        out = await fanout.execute(x=1)
        for ns in ("tool_a", "tool_b"):
            assert ns in out
        assert out["tool_a"]["result"] == "a"
        assert out["tool_b"]["result"] == "b"
# ── DynamicSelector tests ────────────────────────────────
class TestDynamicSelector:
    """``DynamicSelector``: picks one tool per call, by keyword match or via an LLM."""

    async def test_keyword_selection_by_intent(self):
        """An explicit ``_intent`` argument steers the keyword choice."""
        candidates = [_make_tool("search_engine"), _make_tool("calculator")]
        chooser = DynamicSelector(
            tools=candidates,
            mode="keyword",
            name="smart",
            description="Smart selector",
        )

        out = await chooser.execute(query="AI", _intent="search")
        assert out["tool"] == "search_engine"

    async def test_keyword_selection_by_input(self):
        """Without ``_intent``, the tool is inferred from the input text."""
        candidates = [_make_tool("search_data"), _make_tool("calculate_result")]
        chooser = DynamicSelector(
            tools=candidates,
            mode="keyword",
            name="smart",
            description="Smart selector",
        )

        out = await chooser.execute(query="search for trends")
        assert out["tool"] == "search_data"

    async def test_no_matching_tool(self):
        """Nothing matches the input; the selector must still return something."""
        chooser = DynamicSelector(
            tools=[_make_tool("tool_a")],
            mode="keyword",
            name="smart",
            description="Smart selector",
        )

        out = await chooser.execute(x=999)
        # Intended: with no keyword hit (best_score == 0) the selector still
        # defaults to tool_a. The check deliberately also accepts an error
        # payload. NOTE(review): consider pinning out["tool"] == "tool_a" once
        # that default is confirmed against the implementation.
        assert any(key in out for key in ("tool", "error"))

    async def test_llm_selection(self):
        """In LLM mode, the model's reply is used to pick the tool."""

        class MockLLM:
            async def chat(self, messages, **kwargs):
                return "0"  # pick the first tool

        chooser = DynamicSelector(
            tools=[_make_tool("search_engine"), _make_tool("calculator")],
            mode="llm",
            llm_client=MockLLM(),
            name="smart",
            description="Smart selector",
        )

        out = await chooser.execute(query="find information")
        assert out["tool"] == "search_engine"

    async def test_llm_fallback_to_keyword(self):
        """If the LLM call raises, selection falls back to keyword mode."""

        class BrokenLLM:
            async def chat(self, messages, **kwargs):
                raise ConnectionError("LLM unavailable")

        chooser = DynamicSelector(
            tools=[_make_tool("search_engine")],
            mode="llm",
            llm_client=BrokenLLM(),
            name="smart",
            description="Smart selector",
        )

        out = await chooser.execute(query="search something")
        assert out["tool"] == "search_engine"
# ── MCPTool tests ────────────────────────────────────────
class TestMCPTool:
    """``MCPTool``: forwards calls to an MCP client and unpacks the response."""

    async def test_mcp_tool_execute(self):
        """A JSON text payload from the client is decoded into the result."""

        class MockMCPClient:
            async def call_tool(self, name, arguments):
                # Echo the received arguments back inside a JSON text item.
                body = json.dumps({"status": "ok", "data": arguments})
                return {"content": [{"type": "text", "text": body}]}

        remote = MCPTool(
            client=MockMCPClient(),
            name="remote_search",
            description="Remote search via MCP",
        )

        out = await remote.execute(query="test")
        assert out["status"] == "ok"
        assert out["data"]["query"] == "test"

    async def test_mcp_tool_text_response(self):
        """Text that is not JSON comes back under the ``result`` key."""

        class MockMCPClient:
            async def call_tool(self, name, arguments):
                item = {"type": "text", "text": "plain text result"}
                return {"content": [item]}

        remote = MCPTool(
            client=MockMCPClient(),
            name="remote_echo",
            description="Remote echo",
        )

        out = await remote.execute(text="hello")
        assert out["result"] == "plain text result"

    async def test_mcp_tool_raw_response(self):
        """A response without a ``content`` list still exposes its own keys."""

        class MockMCPClient:
            async def call_tool(self, name, arguments):
                return {"raw_key": "raw_value"}

        remote = MCPTool(
            client=MockMCPClient(),
            name="remote_raw",
            description="Raw response",
        )

        out = await remote.execute(x=1)
        assert out["raw_key"] == "raw_value"