geo/backend/tests/test_agent_framework/test_agent_base.py

228 lines
6.7 KiB
Python

"""Agent基类测试"""
import asyncio
import uuid
from datetime import datetime, timezone

import pytest

from app.agent_framework.base import BaseAgent
from app.agent_framework.protocol import (
    AgentCapability,
    AgentStatus,
    TaskMessage,
    TaskResult,
    TaskStatus,
)
class ConcreteTestAgent(BaseAgent):
    """Minimal concrete BaseAgent used as a test double.

    It records whether ``execute`` was invoked and with which task, so tests
    can inspect the call afterwards.
    """

    def __init__(self):
        super().__init__(name="concrete_test_agent", agent_type="test_type", version="1.0.0")
        # Call-recording state, inspected by tests after ``execute`` runs.
        self._execute_called = False
        self._execute_task_data = None

    def get_capabilities(self) -> AgentCapability:
        """Describe this agent: a single ``test_task`` type, concurrency of 3."""
        capability_fields = dict(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=["test_task"],
            max_concurrency=3,
            description="测试用Agent",
        )
        return AgentCapability(**capability_fields)

    async def execute(self, task: TaskMessage) -> TaskResult:
        """Record the incoming task and return a canned COMPLETED result."""
        self._execute_called, self._execute_task_data = True, task
        started = datetime.now(timezone.utc)
        finished = datetime.now(timezone.utc)
        return TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=TaskStatus.COMPLETED,
            output_data={"result": "success"},
            error_message=None,
            started_at=started,
            completed_at=finished,
            metrics={"test": True},
        )
class TestBaseAgent:
    """Tests for BaseAgent behaviour, exercised through ConcreteTestAgent."""

    def test_agent_initialization(self):
        """A fresh agent carries its constructor metadata and starts OFFLINE."""
        agent = ConcreteTestAgent()
        assert agent.name == "concrete_test_agent"
        assert agent.agent_type == "test_type"
        assert agent.version == "1.0.0"
        assert agent.status == AgentStatus.OFFLINE

    def test_get_capabilities(self):
        """get_capabilities returns an AgentCapability describing the agent."""
        agent = ConcreteTestAgent()
        capability = agent.get_capabilities()
        assert isinstance(capability, AgentCapability)
        assert capability.agent_name == "concrete_test_agent"
        assert "test_task" in capability.supported_tasks

    def test_agent_status_transitions(self):
        """The public ``status`` reflects the private ``_status`` field."""
        agent = ConcreteTestAgent()
        # Initial state.
        assert agent.status == AgentStatus.OFFLINE
        # Simulate transitions by writing the private field directly.
        for new_status in (AgentStatus.ONLINE, AgentStatus.BUSY):
            agent._status = new_status
            assert agent.status == new_status

    def test_agent_running_tasks_tracking(self):
        """``_running_tasks`` starts empty and supports add/discard of task ids."""
        agent = ConcreteTestAgent()
        assert len(agent._running_tasks) == 0
        # Simulate tasks starting.
        agent._running_tasks.add("task-1")
        agent._running_tasks.add("task-2")
        assert len(agent._running_tasks) == 2
        # Simulate one task finishing.
        agent._running_tasks.discard("task-1")
        assert len(agent._running_tasks) == 1
        assert "task-1" not in agent._running_tasks
        assert "task-2" in agent._running_tasks

    def test_agent_semaphore_initialization(self):
        """A semaphore sized from max_concurrency starts with that many free slots."""
        agent = ConcreteTestAgent()
        max_concurrency = agent.get_capabilities().max_concurrency
        # ConcreteTestAgent.get_capabilities declares max_concurrency=3.
        assert max_concurrency == 3
        # NOTE(review): the test assigns the semaphore itself; whether
        # BaseAgent creates one on its own is not visible here.
        agent._semaphore = asyncio.Semaphore(max_concurrency)
        assert agent._semaphore._value == max_concurrency
        # Public API: a fresh semaphore with a positive count is not locked.
        assert not agent._semaphore.locked()

    def test_is_idle_property(self):
        """OFFLINE and ONLINE count as idle; BUSY does not.

        BaseAgent's idle API is not visible from this module. If no ``is_idle``
        attribute exists, the test is skipped explicitly. Previously it passed
        silently without asserting anything.
        """
        agent = ConcreteTestAgent()
        if not hasattr(agent, "is_idle"):
            pytest.skip("BaseAgent does not expose is_idle; nothing to test yet")

        def idle() -> bool:
            # Accept either a property or a zero-argument method.
            value = agent.is_idle
            return value() if callable(value) else value

        for idle_status in (AgentStatus.OFFLINE, AgentStatus.ONLINE):
            agent._status = idle_status
            assert idle(), f"{idle_status} should be considered idle"

        agent._status = AgentStatus.BUSY
        assert not idle(), "BUSY should not be considered idle"
class TestTaskMessage:
    """Tests for TaskMessage construction and (de)serialization."""

    def test_task_message_creation(self):
        """Constructor arguments are exposed unchanged as attributes."""
        task_id = str(uuid.uuid4())
        task = TaskMessage(
            task_id=task_id,
            agent_name="test_agent",
            task_type="test_task",
            priority=5,
            input_data={"key": "value"},
            callback_url=None,
            created_at=datetime.now(timezone.utc),
        )
        # Compare against the generated id rather than only checking non-None.
        assert task.task_id == task_id
        assert task.agent_name == "test_agent"
        assert task.task_type == "test_task"
        assert task.priority == 5
        assert task.input_data == {"key": "value"}

    def test_task_message_to_dict(self):
        """to_dict exposes the message fields under their attribute names."""
        task = TaskMessage(
            task_id="test-uuid",
            agent_name="test_agent",
            task_type="test_task",
            priority=5,
            input_data={"key": "value"},
            callback_url=None,
            created_at=datetime.now(timezone.utc),
        )
        data = task.to_dict()
        assert data["task_id"] == "test-uuid"
        assert data["agent_name"] == "test_agent"
        assert data["task_type"] == "test_task"
        assert data["priority"] == 5
        assert data["input_data"] == {"key": "value"}

    def test_task_message_from_dict(self):
        """from_dict rebuilds a TaskMessage from a plain dict with an ISO timestamp."""
        data = {
            "task_id": "test-uuid",
            "agent_name": "test_agent",
            "task_type": "test_task",
            "priority": 5,
            "input_data": {"key": "value"},
            "callback_url": None,
            "created_at": datetime.now(timezone.utc).isoformat(),
        }
        task = TaskMessage.from_dict(data)
        assert task.task_id == "test-uuid"
        assert task.agent_name == "test_agent"
        assert task.task_type == "test_task"
        assert task.priority == 5
        assert task.input_data == {"key": "value"}
class TestTaskResult:
    """Tests for TaskResult construction and serialization."""

    @staticmethod
    def _make_result() -> TaskResult:
        """Build the COMPLETED TaskResult fixture shared by the tests below."""
        return TaskResult(
            task_id="test-uuid",
            agent_name="test_agent",
            status=TaskStatus.COMPLETED,
            output_data={"result": "success"},
            error_message=None,
            started_at=datetime.now(timezone.utc),
            completed_at=datetime.now(timezone.utc),
            metrics={"elapsed": 1.5},
        )

    def test_task_result_creation(self):
        """Constructor arguments are exposed as attributes."""
        result = self._make_result()
        assert result.task_id == "test-uuid"
        assert result.agent_name == "test_agent"
        assert result.status == TaskStatus.COMPLETED

    def test_task_result_to_dict(self):
        """to_dict exposes id, status and output payload."""
        data = self._make_result().to_dict()
        assert data["task_id"] == "test-uuid"
        assert data["status"] == TaskStatus.COMPLETED
        assert data["output_data"] == {"result": "success"}