105 lines
2.9 KiB
Python
105 lines
2.9 KiB
Python
"""Tests for Tool system"""
|
|
|
|
import asyncio
|
|
import pytest
|
|
|
|
from agentkit.tools.base import Tool
|
|
from agentkit.tools.function_tool import FunctionTool
|
|
from agentkit.tools.registry import ToolRegistry
|
|
|
|
|
|
async def add_numbers(a: int, b: int) -> dict:
    """Coroutine fixture tool: return ``{"sum": a + b}``."""
    total = a + b
    return {"sum": total}
|
|
|
|
|
|
def sync_greet(name: str) -> dict:
    """Plain (non-async) fixture tool: return a greeting dict for *name*."""
    greeting = f"Hello, {name}!"
    return {"greeting": greeting}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_function_tool_async():
    """FunctionTool.execute awaits a coroutine function and returns its result."""
    add_tool = FunctionTool(func=add_numbers, name="add", description="Add numbers")
    expected = {"sum": 3}
    assert await add_tool.execute(a=1, b=2) == expected
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_function_tool_sync():
    """FunctionTool.execute is awaitable even when it wraps a plain function."""
    greet_tool = FunctionTool(func=sync_greet, description="Greet someone", name="greet")
    expected = {"greeting": "Hello, World!"}
    assert await greet_tool.execute(name="World") == expected
|
|
|
|
|
|
def test_function_tool_schema_inference():
    """FunctionTool exposes an input schema with a property per wrapped parameter.

    Nothing here is awaited, so this runs as a plain synchronous test; it does
    not need an event loop or the pytest-asyncio marker.
    """
    tool = FunctionTool(name="add", description="Add numbers", func=add_numbers)

    assert tool.input_schema is not None
    properties = tool.input_schema.get("properties", {})
    # add_numbers takes (a, b); both should appear in the inferred schema.
    assert "a" in properties
    assert "b" in properties
|
|
|
|
|
|
def test_tool_registry():
    """A registered tool is reported by has_tool and retrievable by name.

    Registry operations here are not awaited, so this is a plain sync test.
    """
    registry = ToolRegistry()
    tool = FunctionTool(name="add", description="Add numbers", func=add_numbers)
    registry.register(tool)

    assert registry.has_tool("add")
    retrieved = registry.get("add")
    assert retrieved.name == "add"
|
|
|
|
|
|
def test_tool_registry_versioning():
    """Two versions registered under one name both stay retrievable.

    ``get`` without ``version`` is expected to return 2.0.0; ``get`` with
    ``version="1.0.0"`` must return that exact version. Nothing is awaited, so
    this is a plain sync test.
    """
    registry = ToolRegistry()

    v1 = FunctionTool(name="add", description="Add v1", func=add_numbers, version="1.0.0")
    v2 = FunctionTool(name="add", description="Add v2", func=add_numbers, version="2.0.0")
    registry.register(v1)
    registry.register(v2)

    # Default returns latest.
    # NOTE(review): v2 is both the higher version and the last one registered,
    # so this does not distinguish "highest version" from "most recently
    # registered" — confirm which the registry intends.
    latest = registry.get("add")
    assert latest.version == "2.0.0"

    # Can request specific version
    specific = registry.get("add", version="1.0.0")
    assert specific.version == "1.0.0"
|
|
|
|
|
|
def test_tool_registry_list():
    """list_tools returns every registered tool, or only those with a given tag.

    Nothing is awaited, so this is a plain sync test.
    """
    registry = ToolRegistry()

    t1 = FunctionTool(name="add", description="Add", func=add_numbers, tags=["math"])
    t2 = FunctionTool(name="greet", description="Greet", func=sync_greet, tags=["text"])
    registry.register(t1)
    registry.register(t2)

    all_tools = registry.list_tools()
    assert len(all_tools) == 2

    # Filtering by "math" should leave only the add tool.
    math_tools = registry.list_tools(tag="math")
    assert len(math_tools) == 1
    assert math_tools[0].name == "add"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_tool_safe_execute():
    """A RuntimeError raised by the wrapped function surfaces from safe_execute.

    NOTE(review): given the method's name, confirm against the Tool
    implementation that re-raising (rather than wrapping the error or
    returning an error result) is the intended contract.
    """

    async def failing_tool():
        raise RuntimeError("boom")

    always_fails = FunctionTool(func=failing_tool, name="fail", description="Always fails")
    with pytest.raises(RuntimeError):
        await always_fails.safe_execute()
|
|
|
|
|
|
def test_tool_not_found():
    """Looking up a name that was never registered raises ToolNotFoundError.

    ``registry.get`` is not awaited, so this is a plain sync test.
    """
    from agentkit.core.exceptions import ToolNotFoundError

    registry = ToolRegistry()
    with pytest.raises(ToolNotFoundError):
        registry.get("nonexistent")
|