"""Unit tests for ToolSearchIndex, ToolSearchTool and ReAct tiered tool injection.

Scenarios covered:
- ToolSearchIndex: empty index, single tool, relevance ranking across several
  tools, ``top_k`` limiting, no match.
- ToolSearchTool: normal query, empty query, no match, results carry the full
  tool description.
- ReActEngine tiered injection: core/extended split, automatic ``tool_search``
  injection, disabling via config, custom core tool list.
"""

from __future__ import annotations

from typing import Any
from unittest.mock import MagicMock

import pytest

from agentkit.tools.base import Tool
from agentkit.tools.builtin import ToolSearchTool
from agentkit.tools.search import ToolSearchIndex


# ── Test Helpers ──────────────────────────────────────────


class FakeTool(Tool):
    """Minimal concrete Tool used as test data; ``execute`` always succeeds."""

    def __init__(
        self,
        name: str,
        description: str,
        input_schema: dict[str, Any] | None = None,
        tags: list[str] | None = None,
    ):
        super().__init__(
            name=name,
            description=description,
            input_schema=input_schema,
            tags=tags or [],
        )

    async def execute(self, **kwargs: Any) -> dict[str, Any]:
        return {"status": "ok"}


def _make_tools() -> list[Tool]:
    """Build a small, fixed set of tools covering file I/O, web and testing."""
    return [
        FakeTool(
            name="read_file",
            description="Read the contents of a file from the filesystem.",
            input_schema={
                "type": "object",
                "properties": {
                    "path": {"type": "string", "description": "file path to read"},
                },
                "required": ["path"],
            },
            tags=["io", "file"],
        ),
        FakeTool(
            name="write_file",
            description="Write content to a file on the filesystem.",
            input_schema={
                "type": "object",
                "properties": {
                    "path": {"type": "string", "description": "file path to write"},
                    "content": {"type": "string", "description": "content to write"},
                },
                "required": ["path", "content"],
            },
            tags=["io", "file"],
        ),
        FakeTool(
            name="web_search",
            description="Search the web for information using a search engine.",
            input_schema={
                "type": "object",
                "properties": {
                    "query": {"type": "string", "description": "search query"},
                },
                "required": ["query"],
            },
            tags=["web", "search"],
        ),
        FakeTool(
            name="run_tests",
            description="Run project tests to verify code changes.",
            input_schema={
                "type": "object",
                "properties": {
                    "commands": {"type": "array", "description": "test commands"},
                },
            },
            tags=["testing", "verification"],
        ),
    ]


# ── ToolSearchIndex Tests ─────────────────────────────────


class TestToolSearchIndex:
    """BM25 search behaviour of ToolSearchIndex."""

    def test_empty_tools(self):
        """Building an index from an empty tool list does not raise."""
        index = ToolSearchIndex([])
        assert len(index) == 0
        assert index.search("anything") == []

    def test_single_tool_match(self):
        """A single-tool index returns that tool for a matching query."""
        tools = _make_tools()[:1]
        index = ToolSearchIndex(tools)
        results = index.search("read file")
        assert len(results) == 1
        assert results[0].name == "read_file"

    def test_relevance_ranking(self):
        """With several tools indexed, the most relevant one ranks first."""
        tools = _make_tools()
        index = ToolSearchIndex(tools)
        results = index.search("web search")
        assert len(results) > 0
        # web_search should be ranked first.
        assert results[0].name == "web_search"

    def test_top_k_limit(self):
        """``top_k`` caps the number of results below the number of matches."""
        index = ToolSearchIndex(_make_tools())
        # "file" matches both read_file and write_file, so a generous cap
        # returns more than one hit; a cap of 1 must then actually truncate.
        # (A cap of 2 would pass even if top_k were ignored.)
        assert len(index.search("file", top_k=10)) >= 2
        results = index.search("file", top_k=1)
        assert len(results) == 1

    def test_no_match_returns_empty(self):
        """A query that matches nothing returns an empty list."""
        tools = _make_tools()
        index = ToolSearchIndex(tools)
        results = index.search("xyzzy_nonexistent")
        assert results == []

    def test_empty_query_returns_empty(self):
        """Empty and whitespace-only queries return an empty list."""
        tools = _make_tools()
        index = ToolSearchIndex(tools)
        assert index.search("") == []
        assert index.search("   ") == []

    def test_top_k_zero_returns_empty(self):
        """``top_k=0`` returns an empty list."""
        tools = _make_tools()
        index = ToolSearchIndex(tools)
        assert index.search("file", top_k=0) == []

    def test_snake_case_tokenization(self):
        """A snake_case tool name is split into searchable words."""
        tool = FakeTool(
            name="read_file",
            # The description deliberately avoids "read" and "file", so a match
            # can only come from splitting the snake_case name.
            description="Load contents from disk.",
        )
        index = ToolSearchIndex([tool])
        # Searching "read" should match.
        results = index.search("read")
        assert len(results) == 1
        # Searching "file" should match too.
        results = index.search("file")
        assert len(results) == 1

    def test_search_includes_parameter_descriptions(self):
        """Keywords from parameter descriptions are searchable."""
        tool = FakeTool(
            name="custom_tool",
            description="A custom tool.",
            input_schema={
                "type": "object",
                "properties": {
                    "database_url": {
                        "type": "string",
                        "description": "PostgreSQL connection string",
                    },
                },
            },
        )
        index = ToolSearchIndex([tool])
        results = index.search("postgresql database")
        assert len(results) == 1
        assert results[0].name == "custom_tool"

    def test_search_includes_tags(self):
        """Keywords from tags are searchable."""
        tool = FakeTool(
            name="data_tool",
            description="Process data.",
            tags=["etl", "pipeline"],
        )
        index = ToolSearchIndex([tool])
        results = index.search("pipeline etl")
        assert len(results) == 1
        assert results[0].name == "data_tool"

    def test_invalid_k1_raises(self):
        """``k1 < 0`` raises ValueError."""
        with pytest.raises(ValueError, match="k1"):
            ToolSearchIndex(_make_tools(), k1=-1.0)

    def test_invalid_b_raises(self):
        """``b`` outside [0, 1] raises ValueError."""
        with pytest.raises(ValueError, match="b"):
            ToolSearchIndex(_make_tools(), b=1.5)
        with pytest.raises(ValueError, match="b"):
            ToolSearchIndex(_make_tools(), b=-0.1)

    def test_multiple_results_sorted_by_score(self):
        """Several matches come back ordered by descending score."""
        tools = [
            # Inserted first so that the expected order cannot come from
            # insertion order alone.
            FakeTool(
                name="list_files",
                description="List files, with optional search.",
            ),
            FakeTool(name="search_web", description="Search the web."),
            FakeTool(name="unrelated", description="Do something unrelated."),
        ]
        index = ToolSearchIndex(tools)
        results = index.search("search")
        # search_web mentions "search" more often and in a shorter document,
        # so under BM25 (with k1 > 0) it should score strictly higher than
        # list_files. "unrelated" must not be returned at all.
        assert [r.name for r in results] == ["search_web", "list_files"]


# ── ToolSearchTool Tests ──────────────────────────────────


# NOTE(review): the async tests below carry no explicit marker; they presumably
# rely on an async pytest plugin running in auto mode — confirm in the
# project's pytest configuration.
class TestToolSearchTool:
    """Behaviour of the ``tool_search`` built-in tool."""

    def test_tool_name_and_schema(self):
        """The tool name and input schema are as expected."""
        index = ToolSearchIndex(_make_tools())
        tool = ToolSearchTool(search_index=index)
        assert tool.name == "tool_search"
        assert "query" in tool.input_schema["properties"]
        assert "query" in tool.input_schema["required"]

    async def test_execute_returns_results(self):
        """Executing a search returns the full description of matching tools."""
        index = ToolSearchIndex(_make_tools())
        tool = ToolSearchTool(search_index=index)
        result = await tool.execute(query="web search")
        assert result["count"] > 0
        assert result["query"] == "web search"
        first = result["results"][0]
        assert "name" in first
        assert "description" in first
        assert "parameters" in first
        assert first["name"] == "web_search"

    async def test_execute_empty_query_returns_error(self):
        """An empty query returns an error and no results."""
        index = ToolSearchIndex(_make_tools())
        tool = ToolSearchTool(search_index=index)
        result = await tool.execute(query="")
        assert "error" in result
        assert result["results"] == []

    async def test_execute_no_match(self):
        """No match yields empty results plus an explanatory message."""
        index = ToolSearchIndex(_make_tools())
        tool = ToolSearchTool(search_index=index)
        result = await tool.execute(query="zzz_nonexistent")
        assert result["count"] == 0
        assert result["results"] == []
        assert "message" in result

    async def test_execute_respects_top_k(self):
        """``top_k`` caps the number of returned results."""
        index = ToolSearchIndex(_make_tools())
        tool = ToolSearchTool(search_index=index, top_k=1)
        result = await tool.execute(query="file")
        # "file" matches two tools, so the cap of 1 must truncate; asserting
        # "<= 1" would also pass on an empty (broken) result.
        assert result["count"] == 1
        assert len(result["results"]) == 1

    def test_invalid_top_k_raises(self):
        """``top_k < 1`` raises ValueError."""
        index = ToolSearchIndex(_make_tools())
        with pytest.raises(ValueError, match="top_k"):
            ToolSearchTool(search_index=index, top_k=0)


# ── ReActEngine Tiered Injection Tests ────────────────────


class TestReActTieredInjection:
    """Tiered (core vs. extended) tool-description injection in ReActEngine."""

    def _make_engine(self, **kwargs: Any):
        """Build a ReActEngine backed by a mock LLM gateway."""
        from agentkit.core.react import ReActEngine

        gateway = MagicMock()
        return ReActEngine(llm_gateway=gateway, **kwargs)

    def test_core_tools_get_full_description(self):
        """Core tools are injected with their full description, parameters included."""
        engine = self._make_engine()
        tools = [
            FakeTool(
                name="read_file",
                description="Read a file.",
                input_schema={
                    "type": "object",
                    "properties": {
                        "path": {"type": "string", "description": "file path"},
                    },
                    "required": ["path"],
                },
            ),
        ]
        prompt = engine._build_tool_use_prompt(tools)
        # The core-tools section is present.
        assert "核心工具" in prompt
        # Parameter descriptions are injected.
        assert "path" in prompt
        assert "file path" in prompt

    def test_extended_tools_get_one_line_only(self):
        """Extended tools get only a name plus one-line description (no parameters)."""
        engine = self._make_engine()
        tools = [
            FakeTool(
                name="custom_extended",
                description="A custom extended tool for testing.",
                input_schema={
                    "type": "object",
                    "properties": {
                        "secret_param": {
                            "type": "string",
                            "description": "SECRET_PARAM_DESC",
                        },
                    },
                },
            ),
        ]
        prompt = engine._build_tool_use_prompt(tools)
        assert "扩展工具" in prompt
        assert "custom_extended" in prompt
        # Parameter descriptions must not leak into the extended section.
        assert "SECRET_PARAM_DESC" not in prompt

    def test_mixed_core_and_extended(self):
        """Mixed core + extended tools are rendered in separate sections."""
        engine = self._make_engine()
        tools = [
            FakeTool(name="read_file", description="Read a file."),
            FakeTool(
                name="web_search",
                description="Search the web.",
                input_schema={
                    "type": "object",
                    "properties": {
                        "q": {
                            "type": "string",
                            # Sentinel text so a leak of extended-tool
                            # parameters is detectable.
                            "description": "WEB_QUERY_PARAM_DESC",
                        },
                    },
                },
            ),
        ]
        prompt = engine._build_tool_use_prompt(tools)
        assert "核心工具" in prompt
        assert "扩展工具" in prompt
        # read_file is core, web_search is extended.
        assert "read_file" in prompt
        assert "web_search" in prompt
        # web_search is extended, so its parameters are not injected.
        assert "WEB_QUERY_PARAM_DESC" not in prompt

    def test_maybe_add_tool_search_adds_for_extended(self):
        """``tool_search`` is added automatically when extended tools exist."""
        engine = self._make_engine()
        tools = [
            FakeTool(name="read_file", description="Read a file."),  # core
            FakeTool(name="web_search", description="Search the web."),  # extended
        ]
        result = engine._maybe_add_tool_search(tools)
        assert len(result) == 3
        assert any(t.name == "tool_search" for t in result)

    def test_maybe_add_tool_search_skips_when_only_core(self):
        """``tool_search`` is not added when every tool is a core tool."""
        engine = self._make_engine()
        tools = [
            FakeTool(name="read_file", description="Read a file."),
            FakeTool(name="str_replace_editor", description="Edit a file."),
        ]
        result = engine._maybe_add_tool_search(tools)
        assert len(result) == 2
        assert not any(t.name == "tool_search" for t in result)

    def test_maybe_add_tool_search_skips_when_disabled(self):
        """``enable_tool_search=False`` suppresses the automatic ``tool_search``."""
        engine = self._make_engine(enable_tool_search=False)
        tools = [
            FakeTool(name="read_file", description="Read a file."),
            FakeTool(name="web_search", description="Search the web."),
        ]
        result = engine._maybe_add_tool_search(tools)
        assert len(result) == 2
        assert not any(t.name == "tool_search" for t in result)

    def test_maybe_add_tool_search_skips_when_already_present(self):
        """An existing ``tool_search`` is kept and not duplicated."""
        engine = self._make_engine()
        index = ToolSearchIndex([])
        existing_search = ToolSearchTool(search_index=index)
        tools = [
            FakeTool(name="read_file", description="Read a file."),
            FakeTool(name="web_search", description="Search the web."),
            existing_search,
        ]
        result = engine._maybe_add_tool_search(tools)
        assert len(result) == 3
        # Exactly one tool_search, and it is the caller's own instance.
        assert sum(t.name == "tool_search" for t in result) == 1
        assert any(t is existing_search for t in result)

    def test_custom_core_tool_names(self):
        """A custom ``core_tool_names`` list overrides the defaults."""
        engine = self._make_engine(core_tool_names=["my_core_tool"])
        tools = [
            FakeTool(name="my_core_tool", description="My core tool."),
            FakeTool(
                name="read_file",
                description="Read a file.",
                input_schema={
                    "type": "object",
                    "properties": {
                        "p": {"type": "string", "description": "PARAM_DESC"},
                    },
                },
            ),
        ]
        prompt = engine._build_tool_use_prompt(tools)
        # my_core_tool is in the core section.
        assert "核心工具" in prompt
        assert "my_core_tool" in prompt
        # read_file is now an extended tool: the extended section exists and
        # its parameter description is not injected.
        assert "扩展工具" in prompt
        assert "PARAM_DESC" not in prompt

    def test_tool_search_is_core_tool(self):
        """``tool_search`` is treated as a core tool (full description injected)."""
        engine = self._make_engine()
        index = ToolSearchIndex([])
        search_tool = ToolSearchTool(search_index=index)
        tools = [search_tool]
        prompt = engine._build_tool_use_prompt(tools)
        # tool_search is in the core section.
        assert "核心工具" in prompt
        assert "tool_search" in prompt
        # Its "query" parameter is injected.
        assert "query" in prompt

    def test_search_hint_added_when_tool_search_present(self):
        """A search hint is added when tool_search and extended tools are present."""
        engine = self._make_engine()
        tools = [
            FakeTool(name="read_file", description="Read a file."),
            FakeTool(name="web_search", description="Search the web."),
        ]
        # Let _maybe_add_tool_search inject tool_search as the engine would.
        tools_with_search = engine._maybe_add_tool_search(tools)
        prompt = engine._build_tool_use_prompt(tools_with_search)
        assert "tool_search" in prompt
        assert "扩展工具" in prompt
        # The search hint is present.
        assert "tool_search(query" in prompt

    def test_no_search_hint_when_no_extended_tools(self):
        """Neither an extended section nor a search hint appears without extended tools."""
        engine = self._make_engine()
        tools = [
            FakeTool(name="read_file", description="Read a file."),
        ]
        prompt = engine._build_tool_use_prompt(tools)
        assert "扩展工具" not in prompt
        assert "tool_search(query" not in prompt