357 lines
11 KiB
Python
357 lines
11 KiB
Python
"""Tests for BaseAgent - 统一生命周期"""
|
|
|
|
import asyncio
|
|
import pytest
|
|
|
|
from agentkit.core.base import BaseAgent
|
|
from agentkit.core.exceptions import TaskCancelledError, TaskTimeoutError
|
|
from agentkit.core.protocol import (
|
|
AgentCapability,
|
|
AgentStatus,
|
|
CancellationToken,
|
|
TaskMessage,
|
|
TaskResult,
|
|
TaskStatus,
|
|
)
|
|
from datetime import datetime, timezone
|
|
|
|
|
|
class SimpleAgent(BaseAgent):
    """Minimal BaseAgent subclass used as a fixture by the tests.

    ``handle_task`` dispatches on ``task.task_type``:

    * ``"echo"`` returns the task's ``input_data`` under the ``"echo"`` key,
    * ``"fail"`` raises ``ValueError("intentional failure")``,
    * ``"slow"`` sleeps for 10 seconds before returning,
    * anything else returns ``{"status": "ok"}``.

    The three ``on_task_*`` hooks only flip a boolean flag so tests can
    observe which hooks ran.
    """

    def __init__(self):
        super().__init__(name="test_agent", agent_type="test", version="1.0.0")
        # Hook-observation flags; each is set to True by the matching hook.
        self.task_started = self.task_completed = self.task_failed = False

    async def handle_task(self, task: TaskMessage) -> dict:
        """Return a canned result (or raise) according to ``task.task_type``."""
        kind = task.task_type

        if kind == "echo":
            return {"echo": task.input_data}

        if kind == "fail":
            raise ValueError("intentional failure")

        if kind == "slow":
            # Long enough that timeout/cancellation tests can interrupt it.
            await asyncio.sleep(10)
            return {"status": "slow_done"}

        # Unknown task types fall through to a generic success payload.
        return {"status": "ok"}

    def get_capabilities(self) -> AgentCapability:
        """Describe this agent: identity fields, supported tasks, concurrency."""
        capability = AgentCapability(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=["echo", "fail", "slow"],
            max_concurrency=2,
            description="Test agent",
        )
        return capability

    async def on_task_start(self, task):
        """Record that the start hook was called."""
        self.task_started = True

    async def on_task_complete(self, task, output):
        """Record that the completion hook was called."""
        self.task_completed = True

    async def on_task_failed(self, task, error):
        """Record that the failure hook was called."""
        self.task_failed = True
|
|
|
|
|
|
def _make_task(
    task_type: str = "echo",
    input_data: dict | None = None,
    timeout_seconds: int = 300,
) -> TaskMessage:
    """Build a TaskMessage for ``test_agent`` with fixed id ``"test-001"``.

    Args:
        task_type: Value for the message's ``task_type`` field.
        input_data: Payload for the task; a falsy value (``None`` or an
            empty dict) is replaced by a fresh empty dict.
        timeout_seconds: Value for the message's ``timeout_seconds`` field.

    Returns:
        A TaskMessage with priority 0, no callback URL and ``created_at``
        set to the current UTC time.
    """
    payload = input_data if input_data else {}
    created = datetime.now(timezone.utc)
    message = TaskMessage(
        task_id="test-001",
        agent_name="test_agent",
        task_type=task_type,
        priority=0,
        input_data=payload,
        callback_url=None,
        created_at=created,
        timeout_seconds=timeout_seconds,
    )
    return message
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_task_returns_output():
    """An "echo" task is expected to complete and carry the handler's output."""
    expected_output = {"echo": {"msg": "hello"}}

    outcome = await SimpleAgent().execute(_make_task("echo", {"msg": "hello"}))

    assert outcome.status == TaskStatus.COMPLETED
    assert outcome.output_data == expected_output
    assert outcome.error_message is None
    # The result metrics are expected to record which task type was run.
    assert outcome.metrics["task_type"] == "echo"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_task_failure():
    """A handler exception is expected to surface as a FAILED result."""
    outcome = await SimpleAgent().execute(_make_task("fail"))

    assert outcome.status == TaskStatus.FAILED
    # The message and exception class name of the handler's ValueError.
    assert outcome.error_message == "intentional failure"
    assert outcome.metrics["error_type"] == "ValueError"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_lifecycle_hooks():
    """execute() is expected to fire start/complete hooks, or start/failed."""
    agent = SimpleAgent()

    # Successful task: start and complete hooks should both have run.
    await agent.execute(_make_task("echo"))
    assert agent.task_started is True
    assert agent.task_completed is True

    # Reset the flags touched so far before the second run.
    agent.task_started = False
    agent.task_completed = False

    # Failing task: start and failed hooks should both have run.
    await agent.execute(_make_task("fail"))
    assert agent.task_started is True
    assert agent.task_failed is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_execute_wraps_timing():
    """The result of execute() is expected to carry timing information."""
    outcome = await SimpleAgent().execute(_make_task("echo"))

    # Both timestamps must be populated ...
    assert outcome.started_at is not None
    assert outcome.completed_at is not None
    # ... and the measured duration can never be negative.
    assert outcome.metrics["elapsed_seconds"] >= 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_agent_status():
    """A freshly constructed agent is expected to be OFFLINE and non-distributed."""
    fresh_agent = SimpleAgent()

    assert fresh_agent.status == AgentStatus.OFFLINE
    assert fresh_agent.is_distributed is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_tool_injection():
    """use_tool() is expected to register the tool in ``agent.tools``."""
    from agentkit.tools.function_tool import FunctionTool

    async def my_tool(x: int) -> dict:
        return {"doubled": x * 2}

    doubler = FunctionTool(name="doubler", description="Doubles a number", func=my_tool)
    agent = SimpleAgent()
    agent.use_tool(doubler)

    # Exactly one tool is registered and it is the one just injected.
    registered = agent.tools
    assert len(registered) == 1
    assert registered[0].name == "doubler"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_timeout_returns_failed_result():
    """Task exceeding timeout_seconds returns FAILED TaskResult with TaskTimeoutError"""
    agent = SimpleAgent()

    # The "slow" handler sleeps for 10s while the task only allows 1s, so the
    # timeout is expected to fire first. NOTE: this makes the test take about
    # one second of wall-clock time.
    task = TaskMessage(
        task_id="timeout-001",
        agent_name="test_agent",
        task_type="slow",
        priority=0,
        input_data={},
        callback_url=None,
        created_at=datetime.now(timezone.utc),
        timeout_seconds=1,
    )

    result = await agent.execute(task)

    assert result.status == TaskStatus.FAILED
    assert "timed out" in result.error_message
    assert result.metrics["error_type"] == "TaskTimeoutError"
    # A timeout is also expected to trigger the on_task_failed hook.
    assert agent.task_failed is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_cancel_task_sets_token():
    """cancel_task() sets the CancellationToken for a running task"""
    task_id = "cancel-001"
    agent = SimpleAgent()

    slow_task = TaskMessage(
        task_id=task_id,
        agent_name="test_agent",
        task_type="slow",
        priority=0,
        input_data={},
        callback_url=None,
        created_at=datetime.now(timezone.utc),
        timeout_seconds=0,  # no timeout
    )

    # Run the slow task in the background so it can be cancelled mid-flight.
    in_flight = asyncio.create_task(agent.execute(slow_task))

    # Yield briefly so the task can start and register its token.
    await asyncio.sleep(0.05)

    # Request cancellation; True signals that a matching task was found.
    assert agent.cancel_task(task_id) is True

    # The background execution should now finish with a CANCELLED result.
    outcome = await in_flight
    assert outcome.status == TaskStatus.CANCELLED
    assert "cancelled" in outcome.error_message

    # Once the task is done its token must no longer be tracked.
    assert task_id not in agent._active_tokens
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_cancel_nonexistent_task_returns_false():
    """Cancelling a task that doesn't exist returns False"""
    was_cancelled = SimpleAgent().cancel_task("nonexistent")

    assert was_cancelled is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_cancellation_token_protocol():
    """CancellationToken basic protocol: cancel, is_cancelled, check"""
    cancel_token = CancellationToken()

    # A new token starts out not cancelled.
    assert cancel_token.is_cancelled is False

    # cancel() flips the flag ...
    cancel_token.cancel()
    assert cancel_token.is_cancelled is True

    # ... and check() on a cancelled token raises.
    with pytest.raises(TaskCancelledError):
        cancel_token.check()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_timeout_zero_means_no_timeout():
    """timeout_seconds=0 means no timeout enforcement"""
    # The echo task returns immediately; timeout_seconds=0 must not turn it
    # into a failure.
    unbounded_task = _make_task("echo", {"msg": "hello"}, timeout_seconds=0)

    outcome = await SimpleAgent().execute(unbounded_task)

    assert outcome.status == TaskStatus.COMPLETED
    assert outcome.output_data == {"echo": {"msg": "hello"}}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_active_tokens_cleaned_up_after_completion():
    """CancellationToken is removed from _active_tokens after task completes"""
    agent = SimpleAgent()

    outcome = await agent.execute(_make_task("echo"))

    assert outcome.status == TaskStatus.COMPLETED
    # "test-001" is the fixed task id assigned by _make_task.
    assert "test-001" not in agent._active_tokens
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_status_lock_exists():
    """BaseAgent has an asyncio.Lock for status updates"""
    subject = SimpleAgent()

    # The attribute must be present and must be an asyncio lock.
    assert hasattr(subject, "_status_lock")
    assert isinstance(subject._status_lock, asyncio.Lock)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_concurrent_status_updates_no_race():
    """Concurrent _execute_task calls don't cause race conditions on status"""
    # Number of overlapping tasks. The barrier inside SlowAgent needs exactly
    # this many waiters before any handler returns, so the barrier size and
    # the number of launched tasks must stay in sync (a mismatch would hang).
    task_count = 3

    class SlowAgent(BaseAgent):
        """Agent whose handler blocks on a barrier so that all tasks overlap."""

        def __init__(self):
            super().__init__(name="slow_agent", agent_type="test", version="1.0.0")
            self._barrier = asyncio.Barrier(task_count)

        async def handle_task(self, task: TaskMessage) -> dict:
            # All tasks wait at the barrier so they run concurrently.
            await self._barrier.wait()
            return {"result": "ok"}

        def get_capabilities(self) -> AgentCapability:
            return AgentCapability(
                agent_name=self.name,
                agent_type=self.agent_type,
                version=self.version,
                supported_tasks=["test"],
                max_concurrency=10,
                description="Slow test agent",
            )

    slow_agent = SlowAgent()
    # Set the internal state directly so _execute_task can be called without
    # going through whatever startup path normally initialises it.
    slow_agent._status = AgentStatus.ONLINE
    slow_agent._semaphore = asyncio.Semaphore(10)

    # Launch the concurrent tasks, each with its own task id.
    pending = [
        asyncio.create_task(
            slow_agent._execute_task(
                TaskMessage(
                    task_id=f"concurrent-{i}",
                    agent_name="slow_agent",
                    task_type="test",
                    priority=0,
                    input_data={},
                    callback_url=None,
                    created_at=datetime.now(timezone.utc),
                    timeout_seconds=0,
                )
            )
        )
        for i in range(task_count)
    ]

    # Wait for all tasks to complete.
    await asyncio.gather(*pending)

    # After all tasks complete, status should be ONLINE and no running tasks.
    assert slow_agent.status == AgentStatus.ONLINE
    assert len(slow_agent._running_tasks) == 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_status_lock_serializes_transitions():
    """Status lock properly serializes status transitions"""
    agent = SimpleAgent()
    agent._status = AgentStatus.ONLINE
    agent._semaphore = asyncio.Semaphore(10)

    # Chronological log of "busy-<id>" / "online-<id>" events.
    events = []

    async def simulate_task(task_id: str):
        # Entering: register the task and mark the agent BUSY under the lock.
        async with agent._status_lock:
            agent._running_tasks.add(task_id)
            events.append(f"busy-{task_id}")
            agent._status = AgentStatus.BUSY

        # Simulate some work
        await asyncio.sleep(0.01)

        # Leaving: only the last task to finish flips the agent back to ONLINE.
        async with agent._status_lock:
            agent._running_tasks.discard(task_id)
            if agent._running_tasks:
                return
            events.append(f"online-{task_id}")
            agent._status = AgentStatus.ONLINE

    # Run two transitions concurrently
    await asyncio.gather(*(simulate_task(task_id) for task_id in ("t1", "t2")))

    # Both busy transitions should happen before any online transition
    busy_positions = [pos for pos, event in enumerate(events) if event.startswith("busy")]
    online_positions = [pos for pos, event in enumerate(events) if event.startswith("online")]
    assert all(
        busy_at < online_at
        for busy_at in busy_positions
        for online_at in online_positions
    )
    assert agent.status == AgentStatus.ONLINE
|