# fischer-agentkit/tests/unit/test_agent_tool.py (262 lines, 8.4 KiB, Python)

"""Tests for AgentTool, which wraps an Agent so it can be used as a Tool."""
import pytest
from unittest.mock import AsyncMock, MagicMock
from agentkit.tools.agent_tool import AgentTool
from agentkit.core.protocol import TaskStatus
class TestAgentToolInit:
    """Construction-time behaviour of AgentTool."""

    def test_default_attributes(self):
        # Only the four required arguments are supplied; every other
        # attribute must come out as the constructor default listed below.
        tool = AgentTool(
            name="my_agent_tool",
            description="Wraps an agent",
            agent_name="target_agent",
            task_type="analyze",
        )
        expected = {
            "name": "my_agent_tool",
            "description": "Wraps an agent",
            "agent_name": "target_agent",
            "task_type": "analyze",
            "input_mapping": {},
            "output_mapping": {},
            "timeout_seconds": 300,
            "version": "1.0.0",
            "tags": ["agent"],
        }
        for attr, wanted in expected.items():
            assert getattr(tool, attr) == wanted, attr
        # A fresh tool has no dispatcher until set_dispatcher() is called.
        assert tool._dispatcher is None

    def test_custom_attributes(self):
        # Every optional keyword is overridden; each must be stored verbatim.
        tool = AgentTool(
            name="tool",
            description="desc",
            agent_name="agent_a",
            task_type="translate",
            input_mapping={"text": "content"},
            output_mapping={"result": "translation"},
            timeout_seconds=60,
            version="2.0.0",
            tags=["agent", "nlp"],
        )
        expected = {
            "input_mapping": {"text": "content"},
            "output_mapping": {"result": "translation"},
            "timeout_seconds": 60,
            "version": "2.0.0",
            "tags": ["agent", "nlp"],
        }
        for attr, wanted in expected.items():
            assert getattr(tool, attr) == wanted, attr

    def test_set_dispatcher_returns_self(self):
        tool = AgentTool(name="t", description="d", agent_name="a", task_type="t")
        fake_dispatcher = MagicMock()
        # set_dispatcher is fluent: it returns the very same tool instance
        # after storing the dispatcher on it.
        assert tool.set_dispatcher(fake_dispatcher) is tool
        assert tool._dispatcher is fake_dispatcher
class TestAgentToolExecute:
    """Async tests for AgentTool.execute (dispatch, poll, map results)."""

    # NOTE(review): these are bare ``async def`` tests with no asyncio marker;
    # they presumably rely on an async pytest plugin (e.g. pytest-asyncio in
    # auto mode) configured elsewhere — confirm, otherwise they are not awaited.

    @staticmethod
    def _dispatcher_returning(status):
        """Return an AsyncMock dispatcher whose get_task_status always yields *status*."""
        dispatcher = AsyncMock()
        dispatcher.get_task_status.return_value = status
        return dispatcher

    @staticmethod
    def _make_tool(dispatcher=None, **overrides):
        """Build an AgentTool with placeholder required args, optionally wired to *dispatcher*."""
        params = {"name": "t", "description": "d", "agent_name": "a", "task_type": "t"}
        tool = AgentTool(**{**params, **overrides})
        if dispatcher is not None:
            tool.set_dispatcher(dispatcher)
        return tool

    async def test_execute_without_dispatcher_raises(self):
        tool = self._make_tool()
        with pytest.raises(RuntimeError, match="has no dispatcher configured"):
            await tool.execute(query="hello")

    async def test_execute_dispatches_task(self):
        dispatcher = self._dispatcher_returning(
            {"status": "completed", "output_data": {"answer": "world"}}
        )
        tool = self._make_tool(dispatcher, agent_name="target", task_type="ask")
        result = await tool.execute(query="hello")
        assert result == {"answer": "world"}
        dispatcher.dispatch.assert_awaited_once()
        dispatched_task = dispatcher.dispatch.call_args[0][0]
        assert dispatched_task.agent_name == "target"
        assert dispatched_task.task_type == "ask"
        # With no input_mapping the call kwargs become the task payload as-is.
        assert dispatched_task.input_data == {"query": "hello"}

    async def test_execute_with_input_mapping(self):
        dispatcher = self._dispatcher_returning(
            {"status": "completed", "output_data": {"text": "result"}}
        )
        # input_mapping is {task_key: kwarg_name}: kwarg "query" -> task key "content".
        tool = self._make_tool(dispatcher, input_mapping={"content": "query"})
        await tool.execute(query="hello")
        dispatched_task = dispatcher.dispatch.call_args[0][0]
        assert dispatched_task.input_data == {"content": "hello"}

    async def test_execute_without_input_mapping_passes_all_kwargs(self):
        dispatcher = self._dispatcher_returning({"status": "completed", "output_data": {}})
        tool = self._make_tool(dispatcher)
        await tool.execute(x=1, y=2)
        dispatched_task = dispatcher.dispatch.call_args[0][0]
        assert dispatched_task.input_data == {"x": 1, "y": 2}

    async def test_execute_with_output_mapping(self):
        dispatcher = self._dispatcher_returning(
            {
                "status": "completed",
                "output_data": {"translation": "bonjour", "confidence": 0.9},
            }
        )
        # output_mapping is {result_key: output_key}; unmapped outputs are dropped.
        tool = self._make_tool(dispatcher, output_mapping={"result": "translation"})
        result = await tool.execute(text="hello")
        assert result == {"result": "bonjour"}

    async def test_execute_output_mapping_skips_missing_keys(self):
        dispatcher = self._dispatcher_returning(
            {"status": "completed", "output_data": {"translation": "bonjour"}}
        )
        # "confidence" is absent from output_data, so "score" must not appear.
        tool = self._make_tool(
            dispatcher,
            output_mapping={"result": "translation", "score": "confidence"},
        )
        result = await tool.execute(text="hello")
        assert result == {"result": "bonjour"}

    async def test_execute_failed_status_raises(self):
        dispatcher = self._dispatcher_returning({"status": "failed", "error_message": "OOM"})
        tool = self._make_tool(dispatcher)
        with pytest.raises(RuntimeError, match="failed: OOM"):
            await tool.execute()

    async def test_execute_cancelled_returns_empty(self):
        dispatcher = self._dispatcher_returning({"status": "cancelled"})
        tool = self._make_tool(dispatcher)
        result = await tool.execute()
        assert result == {}

    async def test_execute_completed_no_output_data_returns_empty(self):
        dispatcher = self._dispatcher_returning({"status": "completed", "output_data": None})
        tool = self._make_tool(dispatcher)
        result = await tool.execute()
        assert result == {}

    async def test_execute_timeout_raises(self):
        # The task never leaves "running", so execute must give up at the deadline.
        # NOTE(review): this waits on real time (~timeout_seconds=1s); consider
        # faking the clock/poll interval if suite runtime matters.
        dispatcher = self._dispatcher_returning({"status": "running"})
        tool = self._make_tool(dispatcher, timeout_seconds=1)
        with pytest.raises(TimeoutError, match="timed out after 1s"):
            await tool.execute()

    async def test_execute_waits_for_completion(self):
        dispatcher = AsyncMock()
        call_count = 0

        async def mock_status(task_id):
            # Report "running" for the first two polls, then complete.
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                return {"status": "running"}
            return {"status": "completed", "output_data": {"done": True}}

        dispatcher.get_task_status.side_effect = mock_status
        tool = self._make_tool(dispatcher, timeout_seconds=10)
        result = await tool.execute()
        # The payload only exists on the third poll, so this proves execute kept polling.
        assert result == {"done": True}

    async def test_execute_input_mapping_only_maps_matched_keys(self):
        dispatcher = self._dispatcher_returning({"status": "completed", "output_data": {}})
        # "missing_key" is not passed and "other" is not mapped: neither reaches the task.
        tool = self._make_tool(
            dispatcher,
            input_mapping={"content": "query", "extra": "missing_key"},
        )
        await tool.execute(query="hello", other="world")
        dispatched_task = dispatcher.dispatch.call_args[0][0]
        assert dispatched_task.input_data == {"content": "hello"}