fischer-agentkit/tests/unit/test_base_agent.py

140 lines
3.7 KiB
Python

"""Tests for BaseAgent - 统一生命周期"""
import asyncio
import pytest
from agentkit.core.base import BaseAgent
from agentkit.core.protocol import (
AgentCapability,
AgentStatus,
TaskMessage,
TaskResult,
TaskStatus,
)
from datetime import datetime, timezone
class SimpleAgent(BaseAgent):
    """Minimal BaseAgent subclass used as a fixture by the tests in this module.

    It answers two task types ("echo" and "fail") and records, through three
    boolean attributes, which of the ``on_task_*`` hooks have been called.
    """

    def __init__(self):
        super().__init__(name="test_agent", agent_type="test", version="1.0.0")
        # Hook flags: each one is switched on by the matching on_task_* hook.
        self.task_started = self.task_completed = self.task_failed = False

    async def handle_task(self, task: TaskMessage) -> dict:
        """Produce the output for ``task`` based on its ``task_type``.

        Returns:
            ``{"echo": task.input_data}`` for "echo" tasks, otherwise
            ``{"status": "ok"}``.

        Raises:
            ValueError: for "fail" tasks, to simulate a handler failure.
        """
        kind = task.task_type
        if kind == "fail":
            raise ValueError("intentional failure")
        if kind == "echo":
            return {"echo": task.input_data}
        return {"status": "ok"}

    def get_capabilities(self) -> AgentCapability:
        """Describe this agent: its identity fields and the tasks it supports."""
        capability = AgentCapability(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=["echo", "fail"],
            max_concurrency=2,
            description="Test agent",
        )
        return capability

    async def on_task_start(self, task):
        """Hook: remember that a task was started."""
        self.task_started = True

    async def on_task_complete(self, task, output):
        """Hook: remember that a task completed."""
        self.task_completed = True

    async def on_task_failed(self, task, error):
        """Hook: remember that a task failed."""
        self.task_failed = True
def _make_task(task_type: str = "echo", input_data: dict | None = None) -> TaskMessage:
    """Build a TaskMessage for the tests in this module.

    Args:
        task_type: Value stored in the message's ``task_type`` field.
        input_data: Payload for the message; a falsy value (``None`` or an
            empty dict) is replaced by a fresh empty dict.

    Returns:
        A TaskMessage with a fixed id ("test-001"), addressed to
        "test_agent", priority 0, no callback URL, and ``created_at`` set
        to the current UTC time.
    """
    payload = input_data if input_data else {}
    created = datetime.now(timezone.utc)
    return TaskMessage(
        task_id="test-001",
        agent_name="test_agent",
        task_type=task_type,
        priority=0,
        input_data=payload,
        callback_url=None,
        created_at=created,
    )
@pytest.mark.asyncio
async def test_handle_task_returns_output():
    """An "echo" task completes and its input comes back under the "echo" key."""
    outcome = await SimpleAgent().execute(_make_task("echo", {"msg": "hello"}))

    assert outcome.status == TaskStatus.COMPLETED
    assert outcome.output_data == {"echo": {"msg": "hello"}}
    assert outcome.error_message is None
    # The result metrics are expected to record the task type that was run.
    assert outcome.metrics["task_type"] == "echo"
@pytest.mark.asyncio
async def test_handle_task_failure():
    """A "fail" task yields a FAILED result carrying the error message and type."""
    outcome = await SimpleAgent().execute(_make_task("fail"))

    assert outcome.status == TaskStatus.FAILED
    # Message and class name of the ValueError raised by handle_task.
    assert outcome.error_message == "intentional failure"
    assert outcome.metrics["error_type"] == "ValueError"
@pytest.mark.asyncio
async def test_lifecycle_hooks():
    """The start/complete hooks fire on success; start/failed fire on failure."""
    agent = SimpleAgent()

    # Successful task
    await agent.execute(_make_task("echo"))
    assert agent.task_started is True
    assert agent.task_completed is True

    # Reset the flags set by the first run
    agent.task_started = agent.task_completed = False

    # Failing task
    await agent.execute(_make_task("fail"))
    assert agent.task_started is True
    assert agent.task_failed is True
@pytest.mark.asyncio
async def test_execute_wraps_timing():
    """The result of execute() carries start/end timestamps and an elapsed time."""
    outcome = await SimpleAgent().execute(_make_task("echo"))

    assert outcome.started_at is not None
    assert outcome.completed_at is not None
    elapsed = outcome.metrics["elapsed_seconds"]
    assert elapsed >= 0
@pytest.mark.asyncio
async def test_agent_status():
    """A freshly constructed agent reports OFFLINE and is not distributed."""
    fresh_agent = SimpleAgent()

    assert fresh_agent.status == AgentStatus.OFFLINE
    assert fresh_agent.is_distributed is False
@pytest.mark.asyncio
async def test_tool_injection():
    """A tool passed to use_tool() shows up in the agent's ``tools`` list."""
    from agentkit.tools.function_tool import FunctionTool

    async def my_tool(x: int) -> dict:
        return {"doubled": x * 2}

    doubler = FunctionTool(
        name="doubler",
        description="Doubles a number",
        func=my_tool,
    )

    agent = SimpleAgent()
    agent.use_tool(doubler)

    registered = agent.tools
    assert len(registered) == 1
    assert registered[0].name == "doubler"