140 lines
3.7 KiB
Python
140 lines
3.7 KiB
Python
"""Tests for BaseAgent - unified lifecycle."""
|
|
|
|
import asyncio
|
|
import pytest
|
|
|
|
from agentkit.core.base import BaseAgent
|
|
from agentkit.core.protocol import (
|
|
AgentCapability,
|
|
AgentStatus,
|
|
TaskMessage,
|
|
TaskResult,
|
|
TaskStatus,
|
|
)
|
|
from datetime import datetime, timezone
|
|
|
|
|
|
class SimpleAgent(BaseAgent):
    """Minimal ``BaseAgent`` subclass used to drive lifecycle tests.

    Task types handled by :meth:`handle_task`:

    * ``"echo"`` -- returns ``{"echo": task.input_data}``.
    * ``"fail"`` -- raises ``ValueError("intentional failure")``.
    * anything else -- returns ``{"status": "ok"}``.

    The ``task_started`` / ``task_completed`` / ``task_failed`` flags are
    flipped to ``True`` by the corresponding lifecycle hooks so tests can
    observe which hooks ran.
    """

    def __init__(self):
        super().__init__(name="test_agent", agent_type="test", version="1.0.0")
        # Hook-observation flags; each hook below sets exactly one of them.
        self.task_started = False
        self.task_completed = False
        self.task_failed = False

    async def handle_task(self, task: TaskMessage) -> dict:
        """Dispatch on ``task.task_type`` and return the task's output dict."""
        kind = task.task_type
        if kind == "echo":
            return {"echo": task.input_data}
        if kind == "fail":
            # Deliberate failure so tests can exercise the error path.
            raise ValueError("intentional failure")
        return {"status": "ok"}

    def get_capabilities(self) -> AgentCapability:
        """Describe this agent: its identity, the two task types, concurrency 2."""
        spec = dict(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=["echo", "fail"],
            max_concurrency=2,
            description="Test agent",
        )
        return AgentCapability(**spec)

    async def on_task_start(self, task: TaskMessage) -> None:
        """Record that the start hook ran."""
        self.task_started = True

    async def on_task_complete(self, task: TaskMessage, output) -> None:
        """Record that the completion hook ran; *output* is ignored."""
        self.task_completed = True

    async def on_task_failed(self, task: TaskMessage, error) -> None:
        """Record that the failure hook ran; *error* is ignored."""
        self.task_failed = True
|
|
|
|
|
|
def _make_task(task_type: str = "echo", input_data: dict | None = None) -> TaskMessage:
    """Build a ``TaskMessage`` addressed to ``test_agent`` for tests.

    Args:
        task_type: Task type to dispatch; defaults to ``"echo"``.
        input_data: Task payload. Any falsy value (including ``None``) is
            replaced with a fresh empty dict.

    Returns:
        A priority-0 message with the fixed id ``"test-001"``, no callback
        URL, and a timezone-aware UTC creation timestamp.
    """
    payload = input_data or {}
    created = datetime.now(timezone.utc)
    return TaskMessage(
        task_id="test-001",
        agent_name="test_agent",
        task_type=task_type,
        priority=0,
        input_data=payload,
        callback_url=None,
        created_at=created,
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_task_returns_output():
    """An ``echo`` task completes and returns its input wrapped under ``echo``."""
    agent = SimpleAgent()
    outcome = await agent.execute(_make_task("echo", {"msg": "hello"}))

    assert outcome.status == TaskStatus.COMPLETED
    assert outcome.output_data == {"echo": {"msg": "hello"}}
    assert outcome.error_message is None
    assert outcome.metrics["task_type"] == "echo"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_task_failure():
    """An exception in ``handle_task`` yields a FAILED result describing it."""
    agent = SimpleAgent()
    outcome = await agent.execute(_make_task("fail"))

    assert outcome.status == TaskStatus.FAILED
    assert outcome.error_message == "intentional failure"
    assert outcome.metrics["error_type"] == "ValueError"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_lifecycle_hooks():
    """Lifecycle hooks fire only for the matching outcome.

    A successful task must trigger the start and completion hooks but not the
    failure hook. A failing task must trigger the start and failure hooks but
    not the completion hook.
    """
    agent = SimpleAgent()

    # Successful task.
    await agent.execute(_make_task("echo"))
    assert agent.task_started is True
    assert agent.task_completed is True
    assert agent.task_failed is False

    # Reset every flag so the second run is observed in isolation.
    agent.task_started = False
    agent.task_completed = False
    agent.task_failed = False

    # Failing task.
    await agent.execute(_make_task("fail"))
    assert agent.task_started is True
    assert agent.task_failed is True
    # Resetting task_completed above only matters if we check it: the
    # completion hook must not fire for a failed task.
    assert agent.task_completed is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_execute_wraps_timing():
    """``execute`` stamps start/end times and reports a non-negative duration."""
    agent = SimpleAgent()
    outcome = await agent.execute(_make_task("echo"))

    assert outcome.started_at is not None
    assert outcome.completed_at is not None
    assert outcome.metrics["elapsed_seconds"] >= 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_agent_status():
    """A freshly built agent starts OFFLINE and in non-distributed mode."""
    fresh = SimpleAgent()

    assert fresh.status == AgentStatus.OFFLINE
    assert fresh.is_distributed is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_tool_injection():
    """``use_tool`` registers a tool that then appears in ``agent.tools``."""
    from agentkit.tools.function_tool import FunctionTool

    async def my_tool(x: int) -> dict:
        return {"doubled": x * 2}

    agent = SimpleAgent()
    agent.use_tool(
        FunctionTool(name="doubler", description="Doubles a number", func=my_tool)
    )

    assert len(agent.tools) == 1
    assert agent.tools[0].name == "doubler"
|