fischer-agentkit/tests/unit/tools/test_tool_search.py

473 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

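"""Unit tests for ToolSearchIndex, ToolSearchTool, and ReAct tiered tool injection.

Scenarios covered:
- ToolSearchIndex: empty index, single tool, relevance ranking across multiple tools, top_k limit, no match
- ToolSearchTool: normal query, empty query, no match, results include the full description
- ReActEngine tiered injection: core/extended separation, automatic tool_search addition, disabling via config, custom core list
"""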
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
import pytest
from agentkit.tools.base import Tool
from agentkit.tools.builtin import ToolSearchTool
from agentkit.tools.search import ToolSearchIndex
# ── Test Helpers ──────────────────────────────────────────
class FakeTool(Tool):
    """Minimal concrete ``Tool`` used as a test double.

    The constructor forwards its arguments to ``Tool.__init__`` and
    ``execute`` always returns the same success payload.
    """

    def __init__(
        self,
        name: str,
        description: str,
        input_schema: dict[str, Any] | None = None,
        tags: list[str] | None = None,
    ):
        # A missing (or empty) tag list is handed on as a fresh empty list.
        effective_tags = tags if tags else []
        super().__init__(
            name=name,
            description=description,
            input_schema=input_schema,
            tags=effective_tags,
        )

    async def execute(self, **kwargs) -> dict:
        """Ignore every keyword argument and report success."""
        outcome = {"status": "ok"}
        return outcome
def _make_tools() -> list[Tool]:
    """Return the four-tool fixture shared by the tests in this module.

    The tools are, in order: ``read_file``, ``write_file``, ``web_search``
    and ``run_tests``. New ``FakeTool`` instances are built on every call.
    """

    def text_param(description: str) -> dict[str, str]:
        # Every string-typed parameter in the fixture has this shape.
        return {"type": "string", "description": description}

    read_file = FakeTool(
        name="read_file",
        description="Read the contents of a file from the filesystem.",
        input_schema={
            "type": "object",
            "properties": {"path": text_param("file path to read")},
            "required": ["path"],
        },
        tags=["io", "file"],
    )
    write_file = FakeTool(
        name="write_file",
        description="Write content to a file on the filesystem.",
        input_schema={
            "type": "object",
            "properties": {
                "path": text_param("file path to write"),
                "content": text_param("content to write"),
            },
            "required": ["path", "content"],
        },
        tags=["io", "file"],
    )
    web_search = FakeTool(
        name="web_search",
        description="Search the web for information using a search engine.",
        input_schema={
            "type": "object",
            "properties": {"query": text_param("search query")},
            "required": ["query"],
        },
        tags=["web", "search"],
    )
    # run_tests deliberately carries no "required" list in its schema.
    run_tests = FakeTool(
        name="run_tests",
        description="Run project tests to verify code changes.",
        input_schema={
            "type": "object",
            "properties": {
                "commands": {"type": "array", "description": "test commands"},
            },
        },
        tags=["testing", "verification"],
    )
    return [read_file, write_file, web_search, run_tests]
# ── ToolSearchIndex Tests ─────────────────────────────────
class TestToolSearchIndex:
    """Tests for BM25 search over tools via ``ToolSearchIndex``."""

    def test_empty_tools(self):
        """An index built from no tools has length 0 and finds nothing."""
        empty_index = ToolSearchIndex([])
        assert len(empty_index) == 0
        assert empty_index.search("anything") == []

    def test_single_tool_match(self):
        """A one-tool index returns that tool for a matching query."""
        only_read_file = _make_tools()[:1]
        hits = ToolSearchIndex(only_read_file).search("read file")
        assert len(hits) == 1
        assert hits[0].name == "read_file"

    def test_relevance_ranking(self):
        """With several tools indexed, the most relevant one comes first."""
        hits = ToolSearchIndex(_make_tools()).search("web search")
        assert len(hits) > 0
        # web_search is expected at the top of the ranking.
        top_hit = hits[0]
        assert top_hit.name == "web_search"

    def test_top_k_limit(self):
        """``top_k`` caps the number of returned results."""
        hits = ToolSearchIndex(_make_tools()).search("file", top_k=2)
        assert len(hits) <= 2

    def test_no_match_returns_empty(self):
        """A query that matches nothing yields an empty list."""
        hits = ToolSearchIndex(_make_tools()).search("xyzzy_nonexistent")
        assert hits == []

    def test_empty_query_returns_empty(self):
        """Empty and blank queries yield an empty list."""
        searcher = ToolSearchIndex(_make_tools())
        for blank_query in ("", " "):
            assert searcher.search(blank_query) == []

    def test_top_k_zero_returns_empty(self):
        """``top_k=0`` yields an empty list."""
        searcher = ToolSearchIndex(_make_tools())
        assert searcher.search("file", top_k=0) == []

    def test_snake_case_tokenization(self):
        """Each part of a snake_case tool name is searchable."""
        snake_tool = FakeTool(
            name="read_file",
            description="Read file contents.",
        )
        searcher = ToolSearchIndex([snake_tool])
        # Both "read" and "file" must find the single indexed tool.
        for word in ("read", "file"):
            hits = searcher.search(word)
            assert len(hits) == 1

    def test_search_includes_parameter_descriptions(self):
        """Keywords found only in parameter descriptions still match."""
        db_tool = FakeTool(
            name="custom_tool",
            description="A custom tool.",
            input_schema={
                "type": "object",
                "properties": {
                    "database_url": {
                        "type": "string",
                        "description": "PostgreSQL connection string",
                    },
                },
            },
        )
        hits = ToolSearchIndex([db_tool]).search("postgresql database")
        assert len(hits) == 1
        assert hits[0].name == "custom_tool"

    def test_search_includes_tags(self):
        """Keywords found only in tags still match."""
        tagged_tool = FakeTool(
            name="data_tool",
            description="Process data.",
            tags=["etl", "pipeline"],
        )
        hits = ToolSearchIndex([tagged_tool]).search("pipeline etl")
        assert len(hits) == 1
        assert hits[0].name == "data_tool"

    def test_invalid_k1_raises(self):
        """A negative ``k1`` raises ``ValueError``."""
        with pytest.raises(ValueError, match="k1"):
            ToolSearchIndex(_make_tools(), k1=-1.0)

    def test_invalid_b_raises(self):
        """A ``b`` outside [0, 1] raises ``ValueError``."""
        # One value above the range, one below it.
        for out_of_range_b in (1.5, -0.1):
            with pytest.raises(ValueError, match="b"):
                ToolSearchIndex(_make_tools(), b=out_of_range_b)

    def test_multiple_results_sorted_by_score(self):
        """Only the tools mentioning the query term are returned.

        Note: despite the test name, the relative order of the two hits is
        not asserted here.
        """
        candidates = [
            FakeTool(name="search_web", description="Search the web."),
            FakeTool(name="search_files", description="Search files on disk."),
            FakeTool(name="unrelated", description="Do something unrelated."),
        ]
        hits = ToolSearchIndex(candidates).search("search")
        # The two "search" tools come back; "unrelated" does not.
        assert len(hits) == 2
        assert all(hit.name != "unrelated" for hit in hits)
# ── ToolSearchTool Tests ──────────────────────────────────
class TestToolSearchTool:
    """Tests for the ``ToolSearchTool`` built-in tool."""

    def test_tool_name_and_schema(self):
        """The tool is named ``tool_search`` and requires a ``query`` input."""
        search_tool = ToolSearchTool(search_index=ToolSearchIndex(_make_tools()))
        assert search_tool.name == "tool_search"
        assert "query" in search_tool.input_schema["properties"]
        assert "query" in search_tool.input_schema["required"]

    async def test_execute_returns_results(self):
        """A matching query returns full descriptions of the matched tools."""
        search_tool = ToolSearchTool(search_index=ToolSearchIndex(_make_tools()))
        response = await search_tool.execute(query="web search")
        assert response["count"] > 0
        assert response["query"] == "web search"
        first = response["results"][0]
        # Each result entry exposes the tool's name, description and parameters.
        for field in ("name", "description", "parameters"):
            assert field in first
        assert first["name"] == "web_search"

    async def test_execute_empty_query_returns_error(self):
        """An empty query produces an error entry and no results."""
        search_tool = ToolSearchTool(search_index=ToolSearchIndex(_make_tools()))
        response = await search_tool.execute(query="")
        assert "error" in response
        assert response["results"] == []

    async def test_execute_no_match(self):
        """A query with no match returns zero results plus a message."""
        search_tool = ToolSearchTool(search_index=ToolSearchIndex(_make_tools()))
        response = await search_tool.execute(query="zzz_nonexistent")
        assert response["count"] == 0
        assert response["results"] == []
        assert "message" in response

    async def test_execute_respects_top_k(self):
        """The tool's ``top_k`` caps the reported result count."""
        search_tool = ToolSearchTool(
            search_index=ToolSearchIndex(_make_tools()),
            top_k=1,
        )
        response = await search_tool.execute(query="file")
        assert response["count"] <= 1

    def test_invalid_top_k_raises(self):
        """``top_k`` below 1 raises ``ValueError``."""
        # Build the index outside the raises block so only the tool
        # constructor is under test.
        tool_index = ToolSearchIndex(_make_tools())
        with pytest.raises(ValueError, match="top_k"):
            ToolSearchTool(search_index=tool_index, top_k=0)
# ── ReActEngine Tiered Injection Tests ────────────────────
class TestReActTieredInjection:
    """Tests for tiered injection of tool descriptions in ``ReActEngine``."""

    def _make_engine(self, **kwargs: Any):
        """Build a ``ReActEngine`` around a ``MagicMock`` gateway.

        Any keyword arguments are forwarded to the engine constructor.
        """
        from agentkit.core.react import ReActEngine

        return ReActEngine(llm_gateway=MagicMock(), **kwargs)

    def test_core_tools_get_full_description(self):
        """Core tools are injected with their full description and parameters."""
        engine = self._make_engine()
        core_only = [
            FakeTool(
                name="read_file",
                description="Read a file.",
                input_schema={
                    "type": "object",
                    "properties": {
                        "path": {"type": "string", "description": "file path"},
                    },
                    "required": ["path"],
                },
            ),
        ]
        rendered = engine._build_tool_use_prompt(core_only)
        # The core-tools section is present.
        assert "核心工具" in rendered
        # The parameter name and its description were injected.
        assert "path" in rendered
        assert "file path" in rendered

    def test_extended_tools_get_one_line_only(self):
        """Extended tools get name plus one-line description, no parameters."""
        engine = self._make_engine()
        extended_only = [
            FakeTool(
                name="custom_extended",
                description="A custom extended tool for testing.",
                input_schema={
                    "type": "object",
                    "properties": {
                        "secret_param": {
                            "type": "string",
                            "description": "SECRET_PARAM_DESC",
                        },
                    },
                },
            ),
        ]
        rendered = engine._build_tool_use_prompt(extended_only)
        assert "扩展工具" in rendered
        assert "custom_extended" in rendered
        # Parameter descriptions must not leak into the prompt.
        assert "SECRET_PARAM_DESC" not in rendered

    def test_mixed_core_and_extended(self):
        """A core/extended mix renders both sections and both tool names."""
        engine = self._make_engine()
        core_tool = FakeTool(name="read_file", description="Read a file.")
        extended_tool = FakeTool(
            name="web_search",
            description="Search the web.",
            input_schema={
                "type": "object",
                "properties": {
                    "q": {"type": "string", "description": "query string"},
                },
            },
        )
        rendered = engine._build_tool_use_prompt([core_tool, extended_tool])
        assert "核心工具" in rendered
        assert "扩展工具" in rendered
        # read_file belongs to the core section, web_search to the extended one.
        assert "read_file" in rendered
        assert "web_search" in rendered

    def test_maybe_add_tool_search_adds_for_extended(self):
        """``tool_search`` is appended when an extended tool is present."""
        engine = self._make_engine()
        supplied = [
            FakeTool(name="read_file", description="Read a file."),  # core
            FakeTool(name="web_search", description="Search the web."),  # extended
        ]
        augmented = engine._maybe_add_tool_search(supplied)
        assert len(augmented) == 3
        assert "tool_search" in [tool.name for tool in augmented]

    def test_maybe_add_tool_search_skips_when_only_core(self):
        """``tool_search`` is not added when every tool is a core tool."""
        engine = self._make_engine()
        supplied = [
            FakeTool(name="read_file", description="Read a file."),
            FakeTool(name="str_replace_editor", description="Edit a file."),
        ]
        augmented = engine._maybe_add_tool_search(supplied)
        assert len(augmented) == 2
        assert "tool_search" not in [tool.name for tool in augmented]

    def test_maybe_add_tool_search_skips_when_disabled(self):
        """``enable_tool_search=False`` suppresses adding ``tool_search``."""
        engine = self._make_engine(enable_tool_search=False)
        supplied = [
            FakeTool(name="read_file", description="Read a file."),
            FakeTool(name="web_search", description="Search the web."),
        ]
        augmented = engine._maybe_add_tool_search(supplied)
        assert len(augmented) == 2
        assert "tool_search" not in [tool.name for tool in augmented]

    def test_maybe_add_tool_search_skips_when_already_present(self):
        """An existing ``tool_search`` is not duplicated."""
        engine = self._make_engine()
        preexisting_search = ToolSearchTool(search_index=ToolSearchIndex([]))
        supplied = [
            FakeTool(name="read_file", description="Read a file."),
            FakeTool(name="web_search", description="Search the web."),
            preexisting_search,
        ]
        augmented = engine._maybe_add_tool_search(supplied)
        assert len(augmented) == 3

    def test_custom_core_tool_names(self):
        """A custom ``core_tool_names`` list replaces the default core set."""
        engine = self._make_engine(core_tool_names=["my_core_tool"])
        custom_core = FakeTool(name="my_core_tool", description="My core tool.")
        demoted = FakeTool(
            name="read_file",
            description="Read a file.",
            input_schema={
                "type": "object",
                "properties": {
                    "p": {"type": "string", "description": "PARAM_DESC"},
                },
            },
        )
        rendered = engine._build_tool_use_prompt([custom_core, demoted])
        # my_core_tool lands in the core section.
        assert "核心工具" in rendered
        # read_file is now an extended tool, so its parameter text is omitted.
        assert "PARAM_DESC" not in rendered

    def test_tool_search_is_core_tool(self):
        """``tool_search`` is treated as a core tool (full description)."""
        engine = self._make_engine()
        search_tool = ToolSearchTool(search_index=ToolSearchIndex([]))
        rendered = engine._build_tool_use_prompt([search_tool])
        # tool_search appears in the core-tools section.
        assert "核心工具" in rendered
        assert "tool_search" in rendered
        # Its "query" parameter is injected as well.
        assert "query" in rendered

    def test_search_hint_added_when_tool_search_present(self):
        """A search hint is rendered when ``tool_search`` and extended tools coexist."""
        engine = self._make_engine()
        supplied = [
            FakeTool(name="read_file", description="Read a file."),
            FakeTool(name="web_search", description="Search the web."),
        ]
        # Let the engine add tool_search itself, then render the prompt.
        augmented = engine._maybe_add_tool_search(supplied)
        rendered = engine._build_tool_use_prompt(augmented)
        assert "tool_search" in rendered
        assert "扩展工具" in rendered
        # The usage hint for the search tool is present.
        assert "tool_search(query" in rendered

    def test_no_search_hint_when_no_extended_tools(self):
        """No extended-tools section is rendered when there are no extended tools."""
        engine = self._make_engine()
        core_only = [
            FakeTool(name="read_file", description="Read a file."),
        ]
        rendered = engine._build_tool_use_prompt(core_only)
        assert "扩展工具" not in rendered