"""U3 tests: tool composition (SequentialChain, ParallelFanOut, DynamicSelector) + MCPTool."""

import asyncio
import json

import pytest

from agentkit.tools.base import Tool
from agentkit.tools.composition import DynamicSelector, ParallelFanOut, SequentialChain
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.mcp_tool import MCPTool

# NOTE(review): `asyncio`, `pytest` and `Tool` are not referenced in this module.
# The bare `async def` tests presumably rely on an asyncio pytest plugin running
# in auto mode (e.g. `asyncio_mode = "auto"`) — confirm before pruning imports.


# ── Helper: build a simple FunctionTool ──────────────────


def _make_tool(name: str, func=None, **kw):
    """Create a FunctionTool named ``name`` for use in tests.

    When ``func`` is omitted, the tool echoes back ``{"tool": name, "input": kwargs}``,
    so a test can assert which tool actually ran and what it received.
    Any extra keyword arguments are forwarded to the ``FunctionTool`` constructor.
    """
    if func is None:

        async def default_func(**kwargs):
            return {"tool": name, "input": kwargs}

        func = default_func
    return FunctionTool(name=name, description=f"Test tool {name}", func=func, **kw)


# ── SequentialChain tests ────────────────────────────────


class TestSequentialChain:
    """SequentialChain runs tools in order, feeding each dict output to the next tool."""

    async def test_chain_two_tools(self):
        """Two tools run in sequence; the first tool's output becomes the second's input."""

        async def step1(text: str) -> dict:
            return {"processed": text.upper(), "length": len(text)}

        async def step2(processed: str, length: int) -> dict:
            return {"result": f"{processed}({length})"}

        chain = SequentialChain(
            name="upper_and_format",
            description="Uppercase then format",
            tools=[
                FunctionTool(name="step1", description="Step 1", func=step1),
                FunctionTool(name="step2", description="Step 2", func=step2),
            ],
        )
        result = await chain.execute(text="hello")
        assert result["result"] == "HELLO(5)"

    async def test_chain_single_tool(self):
        """A single-tool chain returns that tool's result directly."""
        tool = _make_tool("solo")
        chain = SequentialChain(name="solo_chain", description="Solo", tools=[tool])
        result = await chain.execute(x=1)
        assert result["tool"] == "solo"

    async def test_chain_preserves_dict(self):
        """Every step in the chain returns a dict that the next step consumes."""

        async def add_field(x: int) -> dict:
            return {"x": x, "doubled": x * 2}

        async def add_label(x: int, doubled: int) -> dict:
            return {"x": x, "doubled": doubled, "label": f"{x}->{doubled}"}

        chain = SequentialChain(
            name="enrich",
            description="Enrich data",
            tools=[
                FunctionTool(name="add_field", description="Add", func=add_field),
                FunctionTool(name="add_label", description="Label", func=add_label),
            ],
        )
        result = await chain.execute(x=3)
        assert result["label"] == "3->6"


# ── ParallelFanOut tests ─────────────────────────────────


class TestParallelFanOut:
    """ParallelFanOut runs every tool on the same input and merges their results."""

    async def test_parallel_execution(self):
        """Three tools run in parallel and all of their outputs are merged."""

        async def search_web(query: str) -> dict:
            return {"web_results": [f"web: {query}"]}

        async def search_db(query: str) -> dict:
            return {"db_results": [f"db: {query}"]}

        async def search_cache(query: str) -> dict:
            return {"cache_results": [f"cache: {query}"]}

        fan_out = ParallelFanOut(
            name="multi_search",
            description="Search multiple sources",
            tools=[
                FunctionTool(name="search_web", description="Web", func=search_web),
                FunctionTool(name="search_db", description="DB", func=search_db),
                FunctionTool(name="search_cache", description="Cache", func=search_cache),
            ],
        )
        result = await fan_out.execute(query="AI")
        assert "web_results" in result
        assert "db_results" in result
        assert "cache_results" in result
        # The default merge is flat (see test_parallel_with_error), so each
        # tool's value must come through unchanged, not just its key.
        assert result["web_results"] == ["web: AI"]
        assert result["db_results"] == ["db: AI"]
        assert result["cache_results"] == ["cache: AI"]

    async def test_parallel_with_error(self):
        """A failing tool does not affect the others; its failure is reported in ``_errors``."""

        async def good_tool(x: int) -> dict:
            return {"good": x * 2}

        async def bad_tool(x: int) -> dict:
            raise ValueError("boom")

        fan_out = ParallelFanOut(
            name="mixed",
            description="Mixed results",
            tools=[
                FunctionTool(name="good", description="Good", func=good_tool),
                FunctionTool(name="bad", description="Bad", func=bad_tool),
            ],
        )
        result = await fan_out.execute(x=5)
        assert result["good"] == 10
        assert "_errors" in result
        assert any(e["tool"] == "bad" for e in result["_errors"])

    async def test_parallel_namespace_merge(self):
        """The "namespace" merge strategy keeps each tool's result under its own tool name."""

        async def tool_a(x: int) -> dict:
            return {"result": "a"}

        async def tool_b(x: int) -> dict:
            return {"result": "b"}

        fan_out = ParallelFanOut(
            name="namespaced",
            description="Namespaced",
            tools=[
                FunctionTool(name="tool_a", description="A", func=tool_a),
                FunctionTool(name="tool_b", description="B", func=tool_b),
            ],
            merge_strategy="namespace",
        )
        result = await fan_out.execute(x=1)
        assert "tool_a" in result
        assert "tool_b" in result
        assert result["tool_a"]["result"] == "a"
        assert result["tool_b"]["result"] == "b"


# ── DynamicSelector tests ────────────────────────────────


class TestDynamicSelector:
    """DynamicSelector picks one tool per call, by keyword matching or via an LLM."""

    async def test_keyword_selection_by_intent(self):
        """The ``_intent`` argument drives tool selection."""
        search_tool = _make_tool("search_engine")
        calc_tool = _make_tool("calculator")
        selector = DynamicSelector(
            name="smart",
            description="Smart selector",
            tools=[search_tool, calc_tool],
            mode="keyword",
        )
        result = await selector.execute(query="AI", _intent="search")
        assert result["tool"] == "search_engine"

    async def test_keyword_selection_by_input(self):
        """The tool is inferred from the content of the input arguments."""
        search_tool = _make_tool("search_data")
        calc_tool = _make_tool("calculate_result")
        selector = DynamicSelector(
            name="smart",
            description="Smart selector",
            tools=[search_tool, calc_tool],
            mode="keyword",
        )
        result = await selector.execute(query="search for trends")
        assert result["tool"] == "search_data"

    async def test_no_matching_tool(self):
        """With no keyword match, the selector falls back to a default tool or reports an error."""
        tool_a = _make_tool("tool_a")
        selector = DynamicSelector(
            name="smart",
            description="Smart selector",
            tools=[tool_a],
            mode="keyword",
        )
        result = await selector.execute(x=999)
        # Expected to still pick tool_a (the default when best_score == 0);
        # an error result is tolerated as well.
        assert "tool" in result or "error" in result

    async def test_llm_selection(self):
        """In "llm" mode the LLM's reply chooses the tool."""
        search_tool = _make_tool("search_engine")
        calc_tool = _make_tool("calculator")

        class MockLLM:
            async def chat(self, messages, **kwargs):
                return "0"  # pick the first tool (index 0)

        selector = DynamicSelector(
            name="smart",
            description="Smart selector",
            tools=[search_tool, calc_tool],
            mode="llm",
            llm_client=MockLLM(),
        )
        result = await selector.execute(query="find information")
        assert result["tool"] == "search_engine"

    async def test_llm_fallback_to_keyword(self):
        """When the LLM call fails, selection falls back to keyword mode."""
        search_tool = _make_tool("search_engine")

        class BrokenLLM:
            async def chat(self, messages, **kwargs):
                raise ConnectionError("LLM unavailable")

        selector = DynamicSelector(
            name="smart",
            description="Smart selector",
            tools=[search_tool],
            mode="llm",
            llm_client=BrokenLLM(),
        )
        result = await selector.execute(query="search something")
        assert result["tool"] == "search_engine"


# ── MCPTool tests ────────────────────────────────────────


class TestMCPTool:
    """MCPTool forwards calls to an MCP client and normalizes the response."""

    async def test_mcp_tool_execute(self):
        """MCPTool calls the remote tool through the client and decodes JSON text content."""

        class MockMCPClient:
            async def call_tool(self, name, arguments):
                return {
                    "content": [
                        {"type": "text", "text": json.dumps({"status": "ok", "data": arguments})}
                    ]
                }

        tool = MCPTool(
            name="remote_search",
            description="Remote search via MCP",
            client=MockMCPClient(),
        )
        result = await tool.execute(query="test")
        assert result["status"] == "ok"
        assert result["data"]["query"] == "test"

    async def test_mcp_tool_text_response(self):
        """Non-JSON text content is returned under the ``result`` key."""

        class MockMCPClient:
            async def call_tool(self, name, arguments):
                return {
                    "content": [
                        {"type": "text", "text": "plain text result"}
                    ]
                }

        tool = MCPTool(
            name="remote_echo",
            description="Remote echo",
            client=MockMCPClient(),
        )
        result = await tool.execute(text="hello")
        assert result["result"] == "plain text result"

    async def test_mcp_tool_raw_response(self):
        """A response without the standard ``content`` envelope is passed through as-is."""

        class MockMCPClient:
            async def call_tool(self, name, arguments):
                return {"raw_key": "raw_value"}

        tool = MCPTool(
            name="remote_raw",
            description="Raw response",
            client=MockMCPClient(),
        )
        result = await tool.execute(x=1)
        assert result["raw_key"] == "raw_value"