473 lines
17 KiB
Python
473 lines
17 KiB
Python
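"""Unit tests for ToolSearchIndex, ToolSearchTool, and ReAct tiered tool injection.

Scenarios covered:
- ToolSearchIndex: empty index, single tool, relevance ranking across multiple tools, top_k limit, no match
- ToolSearchTool: normal query, empty query, no match, results include the full description
- ReActEngine tiered injection: core/extended separation, automatic tool_search addition, disabling via configuration, custom core list
"""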
|
||
|
||
from __future__ import annotations
|
||
|
||
from typing import Any
|
||
from unittest.mock import MagicMock
|
||
|
||
import pytest
|
||
|
||
from agentkit.tools.base import Tool
|
||
from agentkit.tools.builtin import ToolSearchTool
|
||
from agentkit.tools.search import ToolSearchIndex
|
||
|
||
|
||
# ── Test Helpers ──────────────────────────────────────────
|
||
|
||
|
||
class FakeTool(Tool):
    """Minimal concrete ``Tool`` used as a stand-in throughout these tests."""

    def __init__(
        self,
        name: str,
        description: str,
        input_schema: dict[str, Any] | None = None,
        tags: list[str] | None = None,
    ):
        # Substitute a fresh empty list when tags are missing (or empty),
        # so the base class always receives a list.
        effective_tags = tags if tags else []
        super().__init__(
            name=name,
            description=description,
            input_schema=input_schema,
            tags=effective_tags,
        )

    async def execute(self, **kwargs) -> dict:
        """Ignore every argument and return a constant success payload."""
        outcome = {"status": "ok"}
        return outcome
|
||
|
||
|
||
def _make_tools() -> list[Tool]:
    """Build the shared set of fake tools used by the tests.

    Returns four tools in a fixed order: ``read_file``, ``write_file``,
    ``web_search`` and ``run_tests``.
    """
    read_file = FakeTool(
        name="read_file",
        description="Read the contents of a file from the filesystem.",
        input_schema={
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "file path to read"},
            },
            "required": ["path"],
        },
        tags=["io", "file"],
    )
    write_file = FakeTool(
        name="write_file",
        description="Write content to a file on the filesystem.",
        input_schema={
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "file path to write"},
                "content": {"type": "string", "description": "content to write"},
            },
            "required": ["path", "content"],
        },
        tags=["io", "file"],
    )
    web_search = FakeTool(
        name="web_search",
        description="Search the web for information using a search engine.",
        input_schema={
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "search query"},
            },
            "required": ["query"],
        },
        tags=["web", "search"],
    )
    # Unlike the other fixtures, this schema declares no "required" list.
    run_tests = FakeTool(
        name="run_tests",
        description="Run project tests to verify code changes.",
        input_schema={
            "type": "object",
            "properties": {
                "commands": {"type": "array", "description": "test commands"},
            },
        },
        tags=["testing", "verification"],
    )
    return [read_file, write_file, web_search, run_tests]
|
||
|
||
|
||
# ── ToolSearchIndex Tests ─────────────────────────────────
|
||
|
||
|
||
class TestToolSearchIndex:
    """BM25 search tests for ToolSearchIndex."""

    def test_empty_tools(self):
        """Building an index from an empty tool list does not raise."""
        empty_index = ToolSearchIndex([])
        assert len(empty_index) == 0
        assert empty_index.search("anything") == []

    def test_single_tool_match(self):
        """A single-tool index returns that tool for a matching query."""
        first_tool, *_ = _make_tools()
        hits = ToolSearchIndex([first_tool]).search("read file")
        assert len(hits) == 1
        assert hits[0].name == "read_file"

    def test_relevance_ranking(self):
        """With several tools indexed, the most relevant one is ranked first."""
        hits = ToolSearchIndex(_make_tools()).search("web search")
        assert len(hits) > 0
        # web_search should come out on top.
        top_hit = hits[0]
        assert top_hit.name == "web_search"

    def test_top_k_limit(self):
        """top_k caps the number of returned results."""
        searcher = ToolSearchIndex(_make_tools())
        hits = searcher.search("file", top_k=2)
        assert len(hits) <= 2

    def test_no_match_returns_empty(self):
        """A query that matches nothing yields an empty list."""
        searcher = ToolSearchIndex(_make_tools())
        assert searcher.search("xyzzy_nonexistent") == []

    def test_empty_query_returns_empty(self):
        """An empty query yields an empty list."""
        searcher = ToolSearchIndex(_make_tools())
        for blank_query in ("", " "):
            assert searcher.search(blank_query) == []

    def test_top_k_zero_returns_empty(self):
        """top_k=0 yields an empty list."""
        searcher = ToolSearchIndex(_make_tools())
        hits = searcher.search("file", top_k=0)
        assert hits == []

    def test_snake_case_tokenization(self):
        """A snake_case tool name is split into searchable tokens."""
        searcher = ToolSearchIndex(
            [FakeTool(name="read_file", description="Read file contents.")]
        )
        # Both "read" and "file" should each match the single tool.
        for word in ("read", "file"):
            hits = searcher.search(word)
            assert len(hits) == 1

    def test_search_includes_parameter_descriptions(self):
        """Keywords found only in parameter descriptions are searchable."""
        schema = {
            "type": "object",
            "properties": {
                "database_url": {
                    "type": "string",
                    "description": "PostgreSQL connection string",
                },
            },
        }
        custom = FakeTool(
            name="custom_tool",
            description="A custom tool.",
            input_schema=schema,
        )
        hits = ToolSearchIndex([custom]).search("postgresql database")
        assert len(hits) == 1
        assert hits[0].name == "custom_tool"

    def test_search_includes_tags(self):
        """Keywords found only in tags are searchable."""
        tagged = FakeTool(
            name="data_tool",
            description="Process data.",
            tags=["etl", "pipeline"],
        )
        hits = ToolSearchIndex([tagged]).search("pipeline etl")
        assert len(hits) == 1
        assert hits[0].name == "data_tool"

    def test_invalid_k1_raises(self):
        """k1 < 0 raises ValueError."""
        with pytest.raises(ValueError, match="k1"):
            ToolSearchIndex(_make_tools(), k1=-1.0)

    def test_invalid_b_raises(self):
        """b outside the [0, 1] range raises ValueError."""
        # One value above the range, one below it.
        for out_of_range in (1.5, -0.1):
            with pytest.raises(ValueError, match="b"):
                ToolSearchIndex(_make_tools(), b=out_of_range)

    def test_multiple_results_sorted_by_score(self):
        """Multiple matching results are sorted by score in descending order."""
        candidates = [
            FakeTool(name="search_web", description="Search the web."),
            FakeTool(name="search_files", description="Search files on disk."),
            FakeTool(name="unrelated", description="Do something unrelated."),
        ]
        hits = ToolSearchIndex(candidates).search("search")
        # The two tools containing "search" are returned; "unrelated" is not.
        assert len(hits) == 2
        assert all(hit.name != "unrelated" for hit in hits)
|
||
|
||
|
||
# ── ToolSearchTool Tests ──────────────────────────────────
|
||
|
||
|
||
class TestToolSearchTool:
    """Tests for the ToolSearchTool tool."""

    def test_tool_name_and_schema(self):
        """The tool exposes the expected name and input schema."""
        search_tool = ToolSearchTool(search_index=ToolSearchIndex(_make_tools()))
        assert search_tool.name == "tool_search"
        schema = search_tool.input_schema
        assert "query" in schema["properties"]
        assert "query" in schema["required"]

    async def test_execute_returns_results(self):
        """Executing a search returns full descriptions of the matching tools."""
        search_tool = ToolSearchTool(search_index=ToolSearchIndex(_make_tools()))
        payload = await search_tool.execute(query="web search")

        assert payload["count"] > 0
        assert payload["query"] == "web search"
        top_entry = payload["results"][0]
        # Every result entry must carry these fields.
        for field in ("name", "description", "parameters"):
            assert field in top_entry
        assert top_entry["name"] == "web_search"

    async def test_execute_empty_query_returns_error(self):
        """An empty query returns an error."""
        search_tool = ToolSearchTool(search_index=ToolSearchIndex(_make_tools()))
        payload = await search_tool.execute(query="")
        assert "error" in payload
        assert payload["results"] == []

    async def test_execute_no_match(self):
        """No match returns empty results together with a hint message."""
        search_tool = ToolSearchTool(search_index=ToolSearchIndex(_make_tools()))
        payload = await search_tool.execute(query="zzz_nonexistent")
        assert payload["count"] == 0
        assert payload["results"] == []
        assert "message" in payload

    async def test_execute_respects_top_k(self):
        """top_k caps the number of returned results."""
        search_tool = ToolSearchTool(
            search_index=ToolSearchIndex(_make_tools()),
            top_k=1,
        )
        payload = await search_tool.execute(query="file")
        assert payload["count"] <= 1

    def test_invalid_top_k_raises(self):
        """top_k < 1 raises ValueError."""
        # Build the index outside the raises-context so that only the
        # ToolSearchTool constructor is expected to raise.
        searcher = ToolSearchIndex(_make_tools())
        with pytest.raises(ValueError, match="top_k"):
            ToolSearchTool(search_index=searcher, top_k=0)
|
||
|
||
|
||
# ── ReActEngine Tiered Injection Tests ────────────────────
|
||
|
||
|
||
class TestReActTieredInjection:
    """Tests for tiered injection of tool descriptions in ReActEngine."""

    def _make_engine(self, **kwargs: Any):
        """Create a ReActEngine backed by a mocked LLM gateway."""
        from agentkit.core.react import ReActEngine

        return ReActEngine(llm_gateway=MagicMock(), **kwargs)

    def test_core_tools_get_full_description(self):
        """Core tools are injected with their full description (parameters included)."""
        read_file = FakeTool(
            name="read_file",
            description="Read a file.",
            input_schema={
                "type": "object",
                "properties": {
                    "path": {"type": "string", "description": "file path"},
                },
                "required": ["path"],
            },
        )
        rendered = self._make_engine()._build_tool_use_prompt([read_file])
        # The core-tools section is present.
        assert "核心工具" in rendered
        # The parameter name and its description are injected.
        for fragment in ("path", "file path"):
            assert fragment in rendered

    def test_extended_tools_get_one_line_only(self):
        """Extended tools are injected as name plus one-line description (no parameters)."""
        schema = {
            "type": "object",
            "properties": {
                "secret_param": {
                    "type": "string",
                    "description": "SECRET_PARAM_DESC",
                },
            },
        }
        extended = FakeTool(
            name="custom_extended",
            description="A custom extended tool for testing.",
            input_schema=schema,
        )
        rendered = self._make_engine()._build_tool_use_prompt([extended])
        assert "扩展工具" in rendered
        assert "custom_extended" in rendered
        # Parameter descriptions must not leak into the extended-tools section.
        assert "SECRET_PARAM_DESC" not in rendered

    def test_mixed_core_and_extended(self):
        """A mix of core and extended tools is rendered in two separate sections."""
        core_tool = FakeTool(name="read_file", description="Read a file.")
        extended_tool = FakeTool(
            name="web_search",
            description="Search the web.",
            input_schema={
                "type": "object",
                "properties": {
                    "q": {"type": "string", "description": "query string"},
                },
            },
        )
        rendered = self._make_engine()._build_tool_use_prompt(
            [core_tool, extended_tool]
        )
        assert "核心工具" in rendered
        assert "扩展工具" in rendered
        # read_file belongs to the core section, web_search to the extended one.
        assert "read_file" in rendered
        assert "web_search" in rendered

    def test_maybe_add_tool_search_adds_for_extended(self):
        """tool_search is added automatically when extended tools are present."""
        toolset = [
            FakeTool(name="read_file", description="Read a file."),  # core
            FakeTool(name="web_search", description="Search the web."),  # extended
        ]
        augmented = self._make_engine()._maybe_add_tool_search(toolset)
        assert len(augmented) == 3
        assert "tool_search" in [tool.name for tool in augmented]

    def test_maybe_add_tool_search_skips_when_only_core(self):
        """tool_search is not added when only core tools are present."""
        toolset = [
            FakeTool(name="read_file", description="Read a file."),
            FakeTool(name="str_replace_editor", description="Edit a file."),
        ]
        augmented = self._make_engine()._maybe_add_tool_search(toolset)
        assert len(augmented) == 2
        assert all(tool.name != "tool_search" for tool in augmented)

    def test_maybe_add_tool_search_skips_when_disabled(self):
        """tool_search is not added when enable_tool_search=False."""
        engine = self._make_engine(enable_tool_search=False)
        toolset = [
            FakeTool(name="read_file", description="Read a file."),
            FakeTool(name="web_search", description="Search the web."),
        ]
        augmented = engine._maybe_add_tool_search(toolset)
        assert len(augmented) == 2
        assert all(tool.name != "tool_search" for tool in augmented)

    def test_maybe_add_tool_search_skips_when_already_present(self):
        """tool_search is not added a second time when it is already present."""
        engine = self._make_engine()
        preexisting = ToolSearchTool(search_index=ToolSearchIndex([]))
        toolset = [
            FakeTool(name="read_file", description="Read a file."),
            FakeTool(name="web_search", description="Search the web."),
            preexisting,
        ]
        augmented = engine._maybe_add_tool_search(toolset)
        assert len(augmented) == 3

    def test_custom_core_tool_names(self):
        """A custom core_tool_names list overrides the default."""
        engine = self._make_engine(core_tool_names=["my_core_tool"])
        custom_core = FakeTool(name="my_core_tool", description="My core tool.")
        demoted = FakeTool(
            name="read_file",
            description="Read a file.",
            input_schema={
                "type": "object",
                "properties": {
                    "p": {"type": "string", "description": "PARAM_DESC"},
                },
            },
        )
        rendered = engine._build_tool_use_prompt([custom_core, demoted])
        # my_core_tool lands in the core section.
        assert "核心工具" in rendered
        # read_file is now an extended tool, so its parameter description is omitted.
        assert "PARAM_DESC" not in rendered

    def test_tool_search_is_core_tool(self):
        """tool_search is treated as a core tool (full description injected)."""
        engine = self._make_engine()
        search_tool = ToolSearchTool(search_index=ToolSearchIndex([]))
        rendered = engine._build_tool_use_prompt([search_tool])
        # tool_search should appear in the core-tools section.
        assert "核心工具" in rendered
        assert "tool_search" in rendered
        # Its "query" parameter should be injected as well.
        assert "query" in rendered

    def test_search_hint_added_when_tool_search_present(self):
        """A search hint is added when tool_search and extended tools are both present."""
        engine = self._make_engine()
        base_tools = [
            FakeTool(name="read_file", description="Read a file."),
            FakeTool(name="web_search", description="Search the web."),
        ]
        # Let _maybe_add_tool_search append tool_search, as the engine would.
        rendered = engine._build_tool_use_prompt(
            engine._maybe_add_tool_search(base_tools)
        )
        assert "tool_search" in rendered
        assert "扩展工具" in rendered
        # The search hint is present.
        assert "tool_search(query" in rendered

    def test_no_search_hint_when_no_extended_tools(self):
        """No search hint is added when there are no extended tools."""
        only_core = [FakeTool(name="read_file", description="Read a file.")]
        rendered = self._make_engine()._build_tool_use_prompt(only_core)
        assert "扩展工具" not in rendered
|