fischer-agentkit/tests/unit/test_base_agent.py

357 lines
11 KiB
Python

"""Tests for BaseAgent - 统一生命周期"""
import asyncio
import pytest
from agentkit.core.base import BaseAgent
from agentkit.core.exceptions import TaskCancelledError, TaskTimeoutError
from agentkit.core.protocol import (
AgentCapability,
AgentStatus,
CancellationToken,
TaskMessage,
TaskResult,
TaskStatus,
)
from datetime import datetime, timezone
class SimpleAgent(BaseAgent):
"""测试用简单 Agent"""
def __init__(self):
super().__init__(name="test_agent", agent_type="test", version="1.0.0")
self.task_started = False
self.task_completed = False
self.task_failed = False
async def handle_task(self, task: TaskMessage) -> dict:
if task.task_type == "echo":
return {"echo": task.input_data}
elif task.task_type == "fail":
raise ValueError("intentional failure")
elif task.task_type == "slow":
await asyncio.sleep(10)
return {"status": "slow_done"}
return {"status": "ok"}
def get_capabilities(self) -> AgentCapability:
return AgentCapability(
agent_name=self.name,
agent_type=self.agent_type,
version=self.version,
supported_tasks=["echo", "fail", "slow"],
max_concurrency=2,
description="Test agent",
)
async def on_task_start(self, task):
self.task_started = True
async def on_task_complete(self, task, output):
self.task_completed = True
async def on_task_failed(self, task, error):
self.task_failed = True
def _make_task(task_type: str = "echo", input_data: dict | None = None, timeout_seconds: int = 300) -> TaskMessage:
    """Build a TaskMessage for ``test_agent`` with the fixed id "test-001".

    A falsy ``input_data`` (None or an empty dict) is replaced by a fresh
    empty dict; ``created_at`` is the current UTC time.
    """
    fields = {
        "task_id": "test-001",
        "agent_name": "test_agent",
        "task_type": task_type,
        "priority": 0,
        "input_data": input_data or {},
        "callback_url": None,
        "created_at": datetime.now(timezone.utc),
        "timeout_seconds": timeout_seconds,
    }
    return TaskMessage(**fields)
@pytest.mark.asyncio
async def test_handle_task_returns_output():
    """An "echo" task completes and returns its input wrapped under "echo"."""
    echo_agent = SimpleAgent()
    outcome = await echo_agent.execute(_make_task("echo", {"msg": "hello"}))

    assert outcome.status == TaskStatus.COMPLETED
    assert outcome.output_data == {"echo": {"msg": "hello"}}
    assert outcome.error_message is None
    assert outcome.metrics["task_type"] == "echo"
@pytest.mark.asyncio
async def test_handle_task_failure():
    """An exception in handle_task yields a FAILED result instead of propagating."""
    agent = SimpleAgent()
    failed = await agent.execute(_make_task("fail"))

    assert failed.status == TaskStatus.FAILED
    # The exception's message and class name are surfaced on the result.
    assert failed.error_message == "intentional failure"
    assert failed.metrics["error_type"] == "ValueError"
@pytest.mark.asyncio
async def test_lifecycle_hooks():
    """Start+complete hooks fire on success; start+failed hooks fire on failure."""
    agent = SimpleAgent()

    # Successful task
    await agent.execute(_make_task("echo"))
    assert agent.task_started is True
    assert agent.task_completed is True

    # Reset the flags before the second run
    agent.task_started = agent.task_completed = False

    # Failing task
    await agent.execute(_make_task("fail"))
    assert agent.task_started is True
    assert agent.task_failed is True
@pytest.mark.asyncio
async def test_execute_wraps_timing():
    """execute() records start/end timestamps and a non-negative elapsed time."""
    timed = await SimpleAgent().execute(_make_task("echo"))

    assert timed.started_at is not None
    assert timed.completed_at is not None
    assert timed.metrics["elapsed_seconds"] >= 0
@pytest.mark.asyncio
async def test_agent_status():
    """A freshly constructed agent reports OFFLINE and is not distributed."""
    fresh = SimpleAgent()

    assert fresh.status == AgentStatus.OFFLINE
    assert fresh.is_distributed is False
@pytest.mark.asyncio
async def test_tool_injection():
    """use_tool() registers a FunctionTool that then appears in agent.tools."""
    from agentkit.tools.function_tool import FunctionTool

    async def my_tool(x: int) -> dict:
        return {"doubled": x * 2}

    doubler_tool = FunctionTool(name="doubler", description="Doubles a number", func=my_tool)
    agent = SimpleAgent()
    agent.use_tool(doubler_tool)

    registered = agent.tools
    assert len(registered) == 1
    assert registered[0].name == "doubler"
@pytest.mark.asyncio
async def test_timeout_returns_failed_result():
    """Task exceeding timeout_seconds returns FAILED TaskResult with TaskTimeoutError"""
    agent = SimpleAgent()
    # The "slow" handler sleeps 10s; a 1s timeout must cut it short.
    # timeout_seconds=0 means "no timeout", so 1s is the smallest whole-second
    # timeout that actually exercises the timeout path.
    task = TaskMessage(
        task_id="timeout-001",
        agent_name="test_agent",
        task_type="slow",
        priority=0,
        input_data={},
        callback_url=None,
        created_at=datetime.now(timezone.utc),
        timeout_seconds=1,
    )

    result = await agent.execute(task)

    assert result.status == TaskStatus.FAILED
    assert "timed out" in result.error_message
    # Tie the check to the exception class rather than a free-floating string.
    assert result.metrics["error_type"] == TaskTimeoutError.__name__
    # A timeout is reported through the on_task_failed hook.
    assert agent.task_failed is True
@pytest.mark.asyncio
async def test_cancel_task_sets_token():
    """cancel_task() sets the CancellationToken for a running task"""
    agent = SimpleAgent()
    task = TaskMessage(
        task_id="cancel-001",
        agent_name="test_agent",
        task_type="slow",
        priority=0,
        input_data={},
        callback_url=None,
        created_at=datetime.now(timezone.utc),
        timeout_seconds=0,  # no timeout
    )
    # Start the 10s "slow" task in the background.
    exec_task = asyncio.create_task(agent.execute(task))
    try:
        # Wait (bounded, ~2s) until execute() has registered the task's token,
        # rather than a fixed sleep that may be too short on a loaded machine.
        for _ in range(200):
            if "cancel-001" in agent._active_tokens:
                break
            await asyncio.sleep(0.01)

        cancelled = agent.cancel_task("cancel-001")
        assert cancelled is True

        # Bound the wait so a broken cancellation fails fast instead of
        # waiting out the handler's full 10s sleep.
        result = await asyncio.wait_for(exec_task, timeout=5)
        assert result.status == TaskStatus.CANCELLED
        assert "cancelled" in result.error_message
        # After the task completes, its token must be cleaned up.
        assert "cancel-001" not in agent._active_tokens
    finally:
        # Don't leave the background task running if an assertion failed.
        if not exec_task.done():
            exec_task.cancel()
@pytest.mark.asyncio
async def test_cancel_nonexistent_task_returns_false():
    """cancel_task() on an unknown task id reports False."""
    was_cancelled = SimpleAgent().cancel_task("nonexistent")
    assert was_cancelled is False
@pytest.mark.asyncio
async def test_cancellation_token_protocol():
    """A token starts uncancelled; after cancel() it reports cancelled and check() raises."""
    tok = CancellationToken()
    assert tok.is_cancelled is False

    tok.cancel()

    assert tok.is_cancelled is True
    with pytest.raises(TaskCancelledError):
        tok.check()
@pytest.mark.asyncio
async def test_timeout_zero_means_no_timeout():
    """timeout_seconds=0 disables timeout enforcement, so a fast echo still completes."""
    agent = SimpleAgent()
    # echo returns immediately; timeout=0 must not interfere with it.
    fast_task = _make_task("echo", {"msg": "hello"}, timeout_seconds=0)
    outcome = await agent.execute(fast_task)

    assert outcome.status == TaskStatus.COMPLETED
    assert outcome.output_data == {"echo": {"msg": "hello"}}
@pytest.mark.asyncio
async def test_active_tokens_cleaned_up_after_completion():
    """A completed task leaves no CancellationToken behind in _active_tokens."""
    agent = SimpleAgent()
    outcome = await agent.execute(_make_task("echo"))

    assert outcome.status == TaskStatus.COMPLETED
    # _make_task always uses the task id "test-001".
    assert "test-001" not in agent._active_tokens
@pytest.mark.asyncio
async def test_status_lock_exists():
    """Every agent carries an asyncio.Lock named _status_lock."""
    agent = SimpleAgent()
    assert hasattr(agent, "_status_lock")

    lock = agent._status_lock
    assert isinstance(lock, asyncio.Lock)
@pytest.mark.asyncio
async def test_concurrent_status_updates_no_race():
    """Concurrent _execute_task calls don't cause race conditions on status"""
    n_tasks = 3

    class SlowAgent(BaseAgent):
        """Agent whose handler only returns once all ``n_tasks`` handlers are running."""

        def __init__(self):
            super().__init__(name="slow_agent", agent_type="test", version="1.0.0")
            # Barrier sized to the number of launched tasks: every handler
            # blocks until all of them are in flight, forcing real overlap.
            self._barrier = asyncio.Barrier(n_tasks)

        async def handle_task(self, task: TaskMessage) -> dict:
            await self._barrier.wait()
            return {"result": "ok"}

        def get_capabilities(self) -> AgentCapability:
            return AgentCapability(
                agent_name=self.name,
                agent_type=self.agent_type,
                version=self.version,
                supported_tasks=["test"],
                max_concurrency=10,
                description="Slow test agent",
            )

    slow_agent = SlowAgent()
    # Set the internal state _execute_task relies on directly, bypassing any
    # start-up path. The semaphore must allow at least n_tasks concurrent
    # executions or the barrier can never release.
    slow_agent._status = AgentStatus.ONLINE
    slow_agent._semaphore = asyncio.Semaphore(10)

    tasks_list = []
    for i in range(n_tasks):
        task = TaskMessage(
            task_id=f"concurrent-{i}",
            agent_name="slow_agent",
            task_type="test",
            priority=0,
            input_data={},
            callback_url=None,
            created_at=datetime.now(timezone.utc),
            timeout_seconds=0,
        )
        tasks_list.append(asyncio.create_task(slow_agent._execute_task(task)))

    # Bound the wait: if the executions were serialized (e.g. a lock held
    # across handle_task), the barrier would never release and gather()
    # would hang forever instead of failing.
    await asyncio.wait_for(asyncio.gather(*tasks_list), timeout=5)

    # After all tasks complete, status should be ONLINE and no running tasks
    assert slow_agent.status == AgentStatus.ONLINE
    assert len(slow_agent._running_tasks) == 0
@pytest.mark.asyncio
async def test_status_lock_serializes_transitions():
    """Status lock properly serializes status transitions"""
    agent = SimpleAgent()
    agent._status = AgentStatus.ONLINE
    transition_order = []

    async def record_status_transition(task_id: str):
        """Mimic BUSY -> work -> ONLINE bookkeeping guarded by _status_lock."""
        async with agent._status_lock:
            agent._running_tasks.add(task_id)
            transition_order.append(f"busy-{task_id}")
            agent._status = AgentStatus.BUSY
        # Simulate some work outside the lock so the two transitions overlap.
        await asyncio.sleep(0.01)
        async with agent._status_lock:
            agent._running_tasks.discard(task_id)
            # Only the last task to finish flips the agent back to ONLINE.
            if not agent._running_tasks:
                transition_order.append(f"online-{task_id}")
                agent._status = AgentStatus.ONLINE

    # Run two transitions concurrently
    await asyncio.gather(
        record_status_transition("t1"),
        record_status_transition("t2"),
    )

    busy_indices = [i for i, t in enumerate(transition_order) if t.startswith("busy")]
    online_indices = [i for i, t in enumerate(transition_order) if t.startswith("online")]
    # Pin the counts first: the all() below is vacuously True if either list
    # is empty. Both tasks go BUSY; only the last finisher goes ONLINE.
    assert len(busy_indices) == 2
    assert len(online_indices) == 1
    # Both busy transitions should happen before any online transition
    assert all(bi < oi for bi in busy_indices for oi in online_indices)
    assert agent.status == AgentStatus.ONLINE
    assert not agent._running_tasks