262 lines
8.4 KiB
Python
262 lines
8.4 KiB
Python
"""Tests for AgentTool - 将 Agent 包装为 Tool"""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
from agentkit.tools.agent_tool import AgentTool
|
|
from agentkit.core.protocol import TaskStatus
|
|
|
|
|
|
class TestAgentToolInit:
    """Tests for how AgentTool stores constructor arguments and its dispatcher."""

    def test_default_attributes(self):
        # Only the four required arguments are supplied; every other
        # attribute must come from the constructor defaults.
        tool = AgentTool(
            name="my_agent_tool",
            description="Wraps an agent",
            agent_name="target_agent",
            task_type="analyze",
        )

        expected = {
            "name": "my_agent_tool",
            "description": "Wraps an agent",
            "agent_name": "target_agent",
            "task_type": "analyze",
            "input_mapping": {},
            "output_mapping": {},
            "timeout_seconds": 300,
            "version": "1.0.0",
            "tags": ["agent"],
        }
        for attr, value in expected.items():
            assert getattr(tool, attr) == value
        # No dispatcher is attached until set_dispatcher() is called.
        assert tool._dispatcher is None

    def test_custom_attributes(self):
        def optional_args():
            # Build fresh objects on every call, so the values passed to the
            # constructor and the values compared against are never the same
            # objects (a mutation inside AgentTool would still be caught).
            return {
                "input_mapping": {"text": "content"},
                "output_mapping": {"result": "translation"},
                "timeout_seconds": 60,
                "version": "2.0.0",
                "tags": ["agent", "nlp"],
            }

        tool = AgentTool(
            name="tool",
            description="desc",
            agent_name="agent_a",
            task_type="translate",
            **optional_args(),
        )

        for attr, value in optional_args().items():
            assert getattr(tool, attr) == value

    def test_set_dispatcher_returns_self(self):
        dispatcher = MagicMock()
        tool = AgentTool(name="t", description="d", agent_name="a", task_type="t")

        # set_dispatcher() is fluent: it hands back the very same tool.
        assert tool.set_dispatcher(dispatcher) is tool
        assert tool._dispatcher is dispatcher
|
|
|
|
|
|
class TestAgentToolExecute:
    """Async tests for AgentTool.execute.

    Each test attaches an ``AsyncMock`` dispatcher. The tool awaits
    ``dispatcher.dispatch(task)``, and the canned ``get_task_status`` payload
    determines what ``execute`` returns or raises.
    """

    @staticmethod
    def _make_tool(**overrides):
        """Build an AgentTool with placeholder arguments; *overrides* replace them."""
        params = {"name": "t", "description": "d", "agent_name": "a", "task_type": "t"}
        params.update(overrides)
        return AgentTool(**params)

    @staticmethod
    def _dispatcher_returning(status):
        """Return an AsyncMock dispatcher whose awaited get_task_status() gives *status*."""
        dispatcher = AsyncMock()
        dispatcher.get_task_status.return_value = status
        return dispatcher

    @staticmethod
    def _dispatched_task(dispatcher):
        """Return the task passed as the first positional argument to dispatch().

        First asserts that dispatch() was awaited exactly once. A tool that
        never dispatched then fails with a clear assertion message instead of
        a TypeError from indexing ``call_args``, which is None when the mock
        was never called.
        """
        dispatcher.dispatch.assert_awaited_once()
        return dispatcher.dispatch.call_args.args[0]

    async def test_execute_without_dispatcher_raises(self):
        tool = self._make_tool()

        with pytest.raises(RuntimeError, match="has no dispatcher configured"):
            await tool.execute(query="hello")

    async def test_execute_dispatches_task(self):
        dispatcher = self._dispatcher_returning(
            {"status": "completed", "output_data": {"answer": "world"}}
        )
        tool = self._make_tool(agent_name="target", task_type="ask")
        tool.set_dispatcher(dispatcher)

        result = await tool.execute(query="hello")

        assert result == {"answer": "world"}
        task = self._dispatched_task(dispatcher)
        assert task.agent_name == "target"
        assert task.task_type == "ask"

    async def test_execute_with_input_mapping(self):
        dispatcher = self._dispatcher_returning(
            {"status": "completed", "output_data": {"text": "result"}}
        )
        tool = self._make_tool(input_mapping={"content": "query"})
        tool.set_dispatcher(dispatcher)

        await tool.execute(query="hello")

        # The "query" kwarg is forwarded to the agent under the key "content".
        assert self._dispatched_task(dispatcher).input_data == {"content": "hello"}

    async def test_execute_without_input_mapping_passes_all_kwargs(self):
        dispatcher = self._dispatcher_returning(
            {"status": "completed", "output_data": {}}
        )
        tool = self._make_tool()
        tool.set_dispatcher(dispatcher)

        await tool.execute(x=1, y=2)

        assert self._dispatched_task(dispatcher).input_data == {"x": 1, "y": 2}

    async def test_execute_with_output_mapping(self):
        dispatcher = self._dispatcher_returning(
            {
                "status": "completed",
                "output_data": {"translation": "bonjour", "confidence": 0.9},
            }
        )
        tool = self._make_tool(output_mapping={"result": "translation"})
        tool.set_dispatcher(dispatcher)

        result = await tool.execute(text="hello")

        # The unmapped output key "confidence" does not appear in the result.
        assert result == {"result": "bonjour"}

    async def test_execute_output_mapping_skips_missing_keys(self):
        dispatcher = self._dispatcher_returning(
            {"status": "completed", "output_data": {"translation": "bonjour"}}
        )
        tool = self._make_tool(
            output_mapping={"result": "translation", "score": "confidence"}
        )
        tool.set_dispatcher(dispatcher)

        result = await tool.execute(text="hello")

        # "score" points at "confidence", which the agent did not return.
        assert result == {"result": "bonjour"}

    async def test_execute_failed_status_raises(self):
        dispatcher = self._dispatcher_returning(
            {"status": "failed", "error_message": "OOM"}
        )
        tool = self._make_tool()
        tool.set_dispatcher(dispatcher)

        with pytest.raises(RuntimeError, match="failed: OOM"):
            await tool.execute()

    async def test_execute_cancelled_returns_empty(self):
        dispatcher = self._dispatcher_returning({"status": "cancelled"})
        tool = self._make_tool()
        tool.set_dispatcher(dispatcher)

        assert await tool.execute() == {}

    async def test_execute_completed_no_output_data_returns_empty(self):
        dispatcher = self._dispatcher_returning(
            {"status": "completed", "output_data": None}
        )
        tool = self._make_tool()
        tool.set_dispatcher(dispatcher)

        assert await tool.execute() == {}

    async def test_execute_timeout_raises(self):
        # The status never leaves "running", so execute() must hit its timeout.
        # NOTE(review): this presumably waits out the real 1s timeout; consider
        # patching the tool's sleep if suite speed matters.
        dispatcher = self._dispatcher_returning({"status": "running"})
        tool = self._make_tool(timeout_seconds=1)
        tool.set_dispatcher(dispatcher)

        with pytest.raises(TimeoutError, match="timed out after 1s"):
            await tool.execute()

    async def test_execute_waits_for_completion(self):
        dispatcher = AsyncMock()
        call_count = 0

        async def mock_status(task_id):
            # Report "running" on the first two polls, then completion.
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                return {"status": "running"}
            return {"status": "completed", "output_data": {"done": True}}

        dispatcher.get_task_status.side_effect = mock_status
        tool = self._make_tool(timeout_seconds=10)
        tool.set_dispatcher(dispatcher)

        assert await tool.execute() == {"done": True}

    async def test_execute_input_mapping_only_maps_matched_keys(self):
        dispatcher = self._dispatcher_returning(
            {"status": "completed", "output_data": {}}
        )
        tool = self._make_tool(
            input_mapping={"content": "query", "extra": "missing_key"}
        )
        tool.set_dispatcher(dispatcher)

        await tool.execute(query="hello", other="world")

        # "extra" has no matching kwarg and "other" is unmapped: neither is sent.
        assert self._dispatched_task(dispatcher).input_data == {"content": "hello"}
|