"""Tests for BaseAgent - 统一生命周期""" import asyncio import pytest from agentkit.core.base import BaseAgent from agentkit.core.exceptions import TaskCancelledError, TaskTimeoutError from agentkit.core.protocol import ( AgentCapability, AgentStatus, CancellationToken, TaskMessage, TaskResult, TaskStatus, ) from datetime import datetime, timezone class SimpleAgent(BaseAgent): """测试用简单 Agent""" def __init__(self): super().__init__(name="test_agent", agent_type="test", version="1.0.0") self.task_started = False self.task_completed = False self.task_failed = False async def handle_task(self, task: TaskMessage) -> dict: if task.task_type == "echo": return {"echo": task.input_data} elif task.task_type == "fail": raise ValueError("intentional failure") elif task.task_type == "slow": await asyncio.sleep(10) return {"status": "slow_done"} return {"status": "ok"} def get_capabilities(self) -> AgentCapability: return AgentCapability( agent_name=self.name, agent_type=self.agent_type, version=self.version, supported_tasks=["echo", "fail", "slow"], max_concurrency=2, description="Test agent", ) async def on_task_start(self, task): self.task_started = True async def on_task_complete(self, task, output): self.task_completed = True async def on_task_failed(self, task, error): self.task_failed = True def _make_task(task_type: str = "echo", input_data: dict | None = None, timeout_seconds: int = 300) -> TaskMessage: return TaskMessage( task_id="test-001", agent_name="test_agent", task_type=task_type, priority=0, input_data=input_data or {}, callback_url=None, created_at=datetime.now(timezone.utc), timeout_seconds=timeout_seconds, ) @pytest.mark.asyncio async def test_handle_task_returns_output(): agent = SimpleAgent() task = _make_task("echo", {"msg": "hello"}) result = await agent.execute(task) assert result.status == TaskStatus.COMPLETED assert result.output_data == {"echo": {"msg": "hello"}} assert result.error_message is None assert result.metrics["task_type"] == "echo" 
@pytest.mark.asyncio
async def test_handle_task_failure():
    """An exception from handle_task yields a FAILED result with error details."""
    agent = SimpleAgent()
    task = _make_task("fail")
    result = await agent.execute(task)
    assert result.status == TaskStatus.FAILED
    assert result.error_message == "intentional failure"
    assert result.metrics["error_type"] == "ValueError"


@pytest.mark.asyncio
async def test_lifecycle_hooks():
    """Lifecycle hooks fire for both successful and failing tasks."""
    agent = SimpleAgent()

    # Successful task: start and complete hooks must run.
    task = _make_task("echo")
    await agent.execute(task)
    assert agent.task_started is True
    assert agent.task_completed is True

    # Reset the flags before the next run.
    agent.task_started = False
    agent.task_completed = False

    # Failing task: start and failed hooks must run.
    task = _make_task("fail")
    await agent.execute(task)
    assert agent.task_started is True
    assert agent.task_failed is True


@pytest.mark.asyncio
async def test_execute_wraps_timing():
    """execute() records start/completion timestamps and elapsed time."""
    agent = SimpleAgent()
    task = _make_task("echo")
    result = await agent.execute(task)
    assert result.started_at is not None
    assert result.completed_at is not None
    assert result.metrics["elapsed_seconds"] >= 0


@pytest.mark.asyncio
async def test_agent_status():
    """A freshly constructed agent is OFFLINE and not distributed."""
    agent = SimpleAgent()
    assert agent.status == AgentStatus.OFFLINE
    assert agent.is_distributed is False


@pytest.mark.asyncio
async def test_tool_injection():
    """use_tool() appends the tool to agent.tools."""
    from agentkit.tools.function_tool import FunctionTool

    async def my_tool(x: int) -> dict:
        return {"doubled": x * 2}

    tool = FunctionTool(name="doubler", description="Doubles a number", func=my_tool)
    agent = SimpleAgent()
    agent.use_tool(tool)
    assert len(agent.tools) == 1
    assert agent.tools[0].name == "doubler"


@pytest.mark.asyncio
async def test_timeout_returns_failed_result():
    """Task exceeding timeout_seconds returns FAILED TaskResult with TaskTimeoutError"""
    agent = SimpleAgent()
    # The "slow" task sleeps 10s. timeout_seconds is an int and 0 means
    # "no timeout", so 1s is the shortest timeout that can cut it short.
    task = TaskMessage(
        task_id="timeout-001",
        agent_name="test_agent",
        task_type="slow",
        priority=0,
        input_data={},
        callback_url=None,
        created_at=datetime.now(timezone.utc),
        timeout_seconds=1,
    )
    result = await agent.execute(task)
    assert result.status == TaskStatus.FAILED
    assert "timed out" in result.error_message
    assert result.metrics["error_type"] == "TaskTimeoutError"
    assert agent.task_failed is True


@pytest.mark.asyncio
async def test_cancel_task_sets_token():
    """cancel_task() sets the CancellationToken for a running task"""
    agent = SimpleAgent()
    # Start a slow task in the background.
    task = TaskMessage(
        task_id="cancel-001",
        agent_name="test_agent",
        task_type="slow",
        priority=0,
        input_data={},
        callback_url=None,
        created_at=datetime.now(timezone.utc),
        timeout_seconds=0,  # no timeout
    )
    exec_task = asyncio.create_task(agent.execute(task))

    # Wait (bounded, ~1s max) until execute() has registered the task's
    # token, instead of relying on one fixed sleep that can be too short
    # on a loaded machine.
    for _ in range(100):
        await asyncio.sleep(0.01)
        if "cancel-001" in agent._active_tokens:
            break

    # Cancel the task.
    cancelled = agent.cancel_task("cancel-001")
    assert cancelled is True

    # Wait for the task to finish.
    result = await exec_task
    assert result.status == TaskStatus.CANCELLED
    assert "cancelled" in result.error_message

    # Once the task has finished, its token must be cleaned up.
    assert "cancel-001" not in agent._active_tokens


@pytest.mark.asyncio
async def test_cancel_nonexistent_task_returns_false():
    """Cancelling a task that doesn't exist returns False"""
    agent = SimpleAgent()
    assert agent.cancel_task("nonexistent") is False


@pytest.mark.asyncio
async def test_cancellation_token_protocol():
    """CancellationToken basic protocol: cancel, is_cancelled, check"""
    token = CancellationToken()
    assert token.is_cancelled is False
    token.cancel()
    assert token.is_cancelled is True
    with pytest.raises(TaskCancelledError):
        token.check()


@pytest.mark.asyncio
async def test_timeout_zero_means_no_timeout():
    """timeout_seconds=0 means no timeout enforcement"""
    agent = SimpleAgent()
    # The echo task is fast; timeout=0 must not interfere with it.
    task = _make_task("echo", {"msg": "hello"}, timeout_seconds=0)
    result = await agent.execute(task)
    assert result.status == TaskStatus.COMPLETED
    assert result.output_data == {"echo": {"msg": "hello"}}


@pytest.mark.asyncio
async def test_active_tokens_cleaned_up_after_completion():
    """CancellationToken is removed from _active_tokens after task completes"""
    agent = SimpleAgent()
    task = _make_task("echo")
    result = await agent.execute(task)
    assert result.status == TaskStatus.COMPLETED
    assert "test-001" not in agent._active_tokens


@pytest.mark.asyncio
async def test_status_lock_exists():
    """BaseAgent has an asyncio.Lock for status updates"""
    agent = SimpleAgent()
    assert hasattr(agent, "_status_lock")
    assert isinstance(agent._status_lock, asyncio.Lock)


@pytest.mark.asyncio
async def test_concurrent_status_updates_no_race():
    """Concurrent _execute_task calls don't cause race conditions on status"""

    # An agent whose handler blocks on a barrier so all tasks overlap.
    class SlowAgent(BaseAgent):
        def __init__(self):
            super().__init__(name="slow_agent", agent_type="test", version="1.0.0")
            self._barrier = asyncio.Barrier(3)

        async def handle_task(self, task: TaskMessage) -> dict:
            # Every task waits here until all three are in flight at once.
            await self._barrier.wait()
            return {"result": "ok"}

        def get_capabilities(self) -> AgentCapability:
            return AgentCapability(
                agent_name=self.name,
                agent_type=self.agent_type,
                version=self.version,
                supported_tasks=["test"],
                max_concurrency=10,
                description="Slow test agent",
            )

    slow_agent = SlowAgent()
    slow_agent._status = AgentStatus.ONLINE
    slow_agent._semaphore = asyncio.Semaphore(10)

    # Launch three concurrent tasks.
    tasks_list = []
    for i in range(3):
        task = TaskMessage(
            task_id=f"concurrent-{i}",
            agent_name="slow_agent",
            task_type="test",
            priority=0,
            input_data={},
            callback_url=None,
            created_at=datetime.now(timezone.utc),
            timeout_seconds=0,
        )
        tasks_list.append(asyncio.create_task(slow_agent._execute_task(task)))

    # Bound the wait: if _execute_task ever serialized these tasks, the
    # barrier would never fill and the test would hang instead of failing.
    await asyncio.wait_for(asyncio.gather(*tasks_list), timeout=5)

    # After all tasks complete, the agent is ONLINE with nothing running.
    assert slow_agent.status == AgentStatus.ONLINE
    assert len(slow_agent._running_tasks) == 0


@pytest.mark.asyncio
async def test_status_lock_serializes_transitions():
    """Status lock properly serializes status transitions"""
    agent = SimpleAgent()
    agent._status = AgentStatus.ONLINE
    transition_order = []

    async def record_status_transition(task_id: str):
        async with agent._status_lock:
            agent._running_tasks.add(task_id)
            transition_order.append(f"busy-{task_id}")
            agent._status = AgentStatus.BUSY
        # Simulate some work so the two coroutines interleave.
        await asyncio.sleep(0.01)
        async with agent._status_lock:
            agent._running_tasks.discard(task_id)
            if not agent._running_tasks:
                transition_order.append(f"online-{task_id}")
                agent._status = AgentStatus.ONLINE

    # Run two transitions concurrently.
    await asyncio.gather(
        record_status_transition("t1"),
        record_status_transition("t2"),
    )

    # Every busy transition must come before any online transition.
    busy_indices = [i for i, t in enumerate(transition_order) if t.startswith("busy")]
    online_indices = [i for i, t in enumerate(transition_order) if t.startswith("online")]
    assert all(bi < oi for bi in busy_indices for oi in online_indices)
    assert agent.status == AgentStatus.ONLINE