228 lines
6.7 KiB
Python
228 lines
6.7 KiB
Python
"""Agent基类测试"""
|
|
import asyncio
import uuid
from datetime import datetime, timezone

import pytest

from app.agent_framework.base import BaseAgent
from app.agent_framework.protocol import (
    AgentCapability,
    AgentStatus,
    TaskMessage,
    TaskResult,
    TaskStatus,
)
|
|
|
|
|
|
class ConcreteTestAgent(BaseAgent):
    """Minimal concrete BaseAgent used as a test double.

    It advertises a single ``test_task`` capability and always reports a
    successful completion. It also records whether ``execute`` ran and which
    task it received, so tests can inspect the call.
    """

    def __init__(self):
        identity = {
            "name": "concrete_test_agent",
            "agent_type": "test_type",
            "version": "1.0.0",
        }
        super().__init__(**identity)
        # Spy state: flipped/filled by execute() so tests can inspect the call.
        self._execute_called, self._execute_task_data = False, None

    def get_capabilities(self) -> AgentCapability:
        """Describe this agent: one ``test_task`` type, concurrency limit 3."""
        return AgentCapability(
            description="测试用Agent",
            max_concurrency=3,
            supported_tasks=["test_task"],
            version=self.version,
            agent_type=self.agent_type,
            agent_name=self.name,
        )

    async def execute(self, task: TaskMessage) -> TaskResult:
        """Record the call and return a canned COMPLETED result for *task*."""
        self._execute_called = True
        self._execute_task_data = task

        started = datetime.now(timezone.utc)
        finished = datetime.now(timezone.utc)
        return TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=TaskStatus.COMPLETED,
            output_data={"result": "success"},
            error_message=None,
            started_at=started,
            completed_at=finished,
            metrics={"test": True},
        )
|
|
|
|
|
|
class TestBaseAgent:
    """Unit tests for BaseAgent, exercised through ConcreteTestAgent."""

    def test_agent_initialization(self):
        """A freshly built agent exposes its identity and starts OFFLINE."""
        agent = ConcreteTestAgent()

        assert agent.name == "concrete_test_agent"
        assert agent.agent_type == "test_type"
        assert agent.version == "1.0.0"
        assert agent.status == AgentStatus.OFFLINE

    def test_get_capabilities(self):
        """get_capabilities() returns an AgentCapability describing the agent."""
        agent = ConcreteTestAgent()
        capability = agent.get_capabilities()

        assert isinstance(capability, AgentCapability)
        assert capability.agent_name == "concrete_test_agent"
        assert "test_task" in capability.supported_tasks

    def test_agent_status_transitions(self):
        """The public ``status`` accessor reflects the private ``_status`` field."""
        agent = ConcreteTestAgent()

        # Initial state.
        assert agent.status == AgentStatus.OFFLINE

        # The private field is set directly, so this checks the accessor,
        # not any transition logic BaseAgent may implement.
        for target in (AgentStatus.ONLINE, AgentStatus.BUSY):
            agent._status = target
            assert agent.status == target

    def test_agent_running_tasks_tracking(self):
        """``_running_tasks`` starts empty and behaves as a set of task ids."""
        agent = ConcreteTestAgent()

        assert len(agent._running_tasks) == 0

        # Simulate tasks being added.
        agent._running_tasks.add("task-1")
        agent._running_tasks.add("task-2")
        assert len(agent._running_tasks) == 2

        # Simulate a task finishing.
        agent._running_tasks.discard("task-1")
        assert len(agent._running_tasks) == 1
        assert "task-1" not in agent._running_tasks
        assert "task-2" in agent._running_tasks

    def test_agent_semaphore_initialization(self):
        """A semaphore sized from the advertised concurrency has that many slots."""
        agent = ConcreteTestAgent()
        max_concurrency = agent.get_capabilities().max_concurrency

        # ConcreteTestAgent advertises a concurrency limit of 3.
        assert max_concurrency == 3

        # NOTE(review): the semaphore is assigned by the test itself, so this
        # only verifies asyncio.Semaphore sizing, not BaseAgent's own setup.
        agent._semaphore = asyncio.Semaphore(max_concurrency)

        assert agent._semaphore is not None
        assert agent._semaphore._value == max_concurrency

    def test_is_idle_property(self):
        """OFFLINE and ONLINE count as idle; BUSY does not.

        Before this change the test asserted nothing and always passed. It now
        checks ``is_idle`` when BaseAgent provides it (as a property or a
        method) and is reported as skipped otherwise.
        """
        agent = ConcreteTestAgent()
        if not hasattr(agent, "is_idle"):
            pytest.skip("BaseAgent does not expose is_idle; nothing to verify")

        def idle() -> bool:
            # Support both a property and a zero-argument method.
            attr = agent.is_idle
            return attr() if callable(attr) else attr

        # OFFLINE or ONLINE should be treated as idle.
        for status in (AgentStatus.OFFLINE, AgentStatus.ONLINE):
            agent._status = status
            assert idle(), f"expected idle in {status}"

        # BUSY is not idle.
        agent._status = AgentStatus.BUSY
        assert not idle()
|
|
|
|
|
|
class TestTaskMessage:
    """Tests for TaskMessage construction and dict (de)serialization."""

    def test_task_message_creation(self):
        """Constructor arguments are exposed unchanged as attributes."""
        task_id = str(uuid.uuid4())
        task = TaskMessage(
            task_id=task_id,
            agent_name="test_agent",
            task_type="test_task",
            priority=5,
            input_data={"key": "value"},
            callback_url=None,
            created_at=datetime.now(timezone.utc),
        )

        # Compare to the generated id: `is not None` is always true here.
        assert task.task_id == task_id
        assert task.agent_name == "test_agent"
        assert task.task_type == "test_task"
        assert task.priority == 5
        assert task.input_data == {"key": "value"}
        assert task.callback_url is None

    def test_task_message_to_dict(self):
        """to_dict() emits the identifying fields under their attribute names."""
        task = TaskMessage(
            task_id="test-uuid",
            agent_name="test_agent",
            task_type="test_task",
            priority=5,
            input_data={"key": "value"},
            callback_url=None,
            created_at=datetime.now(timezone.utc),
        )

        data = task.to_dict()

        assert data["task_id"] == "test-uuid"
        assert data["agent_name"] == "test_agent"
        assert data["task_type"] == "test_task"

    def test_task_message_from_dict(self):
        """from_dict() restores every field supplied in the input mapping."""
        data = {
            "task_id": "test-uuid",
            "agent_name": "test_agent",
            "task_type": "test_task",
            "priority": 5,
            "input_data": {"key": "value"},
            "callback_url": None,
            "created_at": datetime.now(timezone.utc).isoformat(),
        }

        task = TaskMessage.from_dict(data)

        assert task.task_id == "test-uuid"
        assert task.agent_name == "test_agent"
        # Check the payload fields too, so a from_dict that silently drops
        # them is caught.
        assert task.task_type == "test_task"
        assert task.priority == 5
        assert task.input_data == {"key": "value"}
        assert task.callback_url is None
|
|
|
|
|
|
class TestTaskResult:
    """Tests for TaskResult construction and dict serialization."""

    @staticmethod
    def _completed_result() -> TaskResult:
        """Build the COMPLETED result fixture shared by the tests below."""
        return TaskResult(
            task_id="test-uuid",
            agent_name="test_agent",
            status=TaskStatus.COMPLETED,
            output_data={"result": "success"},
            error_message=None,
            started_at=datetime.now(timezone.utc),
            completed_at=datetime.now(timezone.utc),
            metrics={"elapsed": 1.5},
        )

    def test_task_result_creation(self):
        """Constructor arguments are exposed as attributes."""
        result = self._completed_result()

        assert result.task_id == "test-uuid"
        assert result.agent_name == "test_agent"
        assert result.status == TaskStatus.COMPLETED

    def test_task_result_to_dict(self):
        """to_dict() carries id, status and output payload."""
        serialized = self._completed_result().to_dict()

        assert serialized["task_id"] == "test-uuid"
        assert serialized["status"] == TaskStatus.COMPLETED
        assert serialized["output_data"] == {"result": "success"}
|
|
|
|
|
|
import uuid
|