256 lines
8.8 KiB
Python
256 lines
8.8 KiB
Python
"""BaseAgent 基类 - 定义所有 Agent 的标准生命周期"""
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
from abc import ABC, abstractmethod
|
||
from datetime import datetime, timezone
|
||
|
||
import redis.asyncio as aioredis
|
||
|
||
from app.agent_framework.exceptions import AgentNotReadyError
|
||
from app.agent_framework.protocol import (
|
||
AgentCapability,
|
||
AgentStatus,
|
||
TaskMessage,
|
||
TaskProgress,
|
||
TaskResult,
|
||
TaskStatus,
|
||
)
|
||
from app.config import settings
|
||
|
||
logger = logging.getLogger(__name__)

# The TaskDispatcher is created lazily and cached at module level: callers share
# one instance instead of building a new one per call, and its import is put off
# until first use so loading this module cannot trigger a circular import.
_dispatcher_instance = None


def _get_dispatcher():
    """Return the shared TaskDispatcher, constructing it on the first call."""
    global _dispatcher_instance
    if _dispatcher_instance is not None:
        return _dispatcher_instance

    # Imported here rather than at module top to avoid a circular dependency.
    from app.agent_framework.dispatcher import TaskDispatcher

    _dispatcher_instance = TaskDispatcher(settings.REDIS_URL)
    return _dispatcher_instance
|
||
|
||
|
||
# The AgentRegistry is likewise created lazily and cached at module level; its
# import is deferred to first use to avoid circular imports at load time.
_registry_instance = None


def _get_registry():
    """Return the shared AgentRegistry, constructing it on the first call."""
    global _registry_instance
    if _registry_instance is not None:
        return _registry_instance

    # Imported here rather than at module top to avoid a circular dependency.
    from app.agent_framework.registry import AgentRegistry

    _registry_instance = AgentRegistry()
    return _registry_instance
|
||
|
||
|
||
class BaseAgent(ABC):
    """Base class for all agents; defines the standard agent lifecycle.

    ``start()`` tries distributed mode first: connect to Redis
    (``settings.REDIS_URL``), register with the AgentRegistry, then run a
    heartbeat loop and a listener that pops JSON-encoded ``TaskMessage``
    payloads from the Redis list ``agent:{name}:tasks``. If that setup fails,
    the agent runs in local mode (no Redis, no registration, no background
    loops). ``stop()`` tears everything down again.

    Subclasses must implement ``execute()`` and ``get_capabilities()``.
    """

    # Seconds between heartbeats sent to the registry.
    HEARTBEAT_INTERVAL_SECONDS: float = 30
    # BRPOP timeout: how long the listener blocks before re-checking status.
    TASK_POLL_TIMEOUT_SECONDS: int = 1
    # Pause before the listener retries after a Redis error / while Redis is unset.
    LISTENER_RETRY_DELAY_SECONDS: float = 1.0

    def __init__(self, name: str, agent_type: str, version: str = "1.0.0"):
        self.name = name
        self.agent_type = agent_type
        self.version = version
        self._status: AgentStatus = AgentStatus.OFFLINE
        # Non-None only in distributed mode; None means local mode / stopped.
        self._redis: aioredis.Redis | None = None
        # IDs of tasks currently inside _execute_task; drives ONLINE <-> BUSY.
        self._running_tasks: set[str] = set()
        # Strong references to the fire-and-forget execution tasks spawned by the
        # listener: the event loop keeps only weak references to tasks, so an
        # unreferenced task can be garbage-collected mid-execution.
        self._inflight_tasks: set[asyncio.Task] = set()
        self._listen_task: asyncio.Task | None = None
        self._heartbeat_task: asyncio.Task | None = None
        # Created in start() from the capability's max_concurrency.
        self._semaphore: asyncio.Semaphore | None = None

    @property
    def status(self) -> AgentStatus:
        return self._status

    @property
    def is_distributed(self) -> bool:
        return self._redis is not None

    @abstractmethod
    async def execute(self, task: TaskMessage) -> TaskResult:
        """Core task logic; subclasses must implement it.

        The framework overwrites ``started_at`` / ``completed_at`` on the
        returned result; an exception raised here becomes a FAILED result.
        """
        ...

    @abstractmethod
    def get_capabilities(self) -> AgentCapability:
        """Return this agent's capability declaration."""
        ...

    def _is_running(self) -> bool:
        """True while the agent is started: ONLINE, or BUSY executing tasks.

        Background loops must keep running while BUSY; checking only for
        ONLINE would end them as soon as a task was in progress.
        """
        return self._status in (AgentStatus.ONLINE, AgentStatus.BUSY)

    def _init_semaphore(self, capability: AgentCapability) -> int:
        """Create the concurrency semaphore and return its limit.

        The limit is ``capability.max_concurrency``; a missing, None, zero or
        negative value falls back to 1 (a negative value would otherwise make
        ``asyncio.Semaphore`` raise ValueError).
        """
        limit = getattr(capability, "max_concurrency", 1) or 1
        limit = max(limit, 1)
        self._semaphore = asyncio.Semaphore(limit)
        return limit

    async def _close_redis(self) -> None:
        """Close and forget the Redis client, if any; close errors are only logged."""
        client, self._redis = self._redis, None
        if client is None:
            return
        try:
            await client.close()
        except Exception as e:
            logger.warning(f"Error closing Redis connection for agent '{self.name}': {e}")

    async def start(self):
        """Start the agent, preferring distributed mode and falling back to local.

        Distributed mode: connect to Redis, ping it and register with the
        registry, then launch the heartbeat and task-listener loops. If any of
        the connect/register steps raise, the half-opened Redis client is closed
        and the agent runs in local mode (ONLINE, no Redis, no loops).

        Raises:
            Whatever ``get_capabilities()`` raises; the agent then stays OFFLINE.
        """
        logger.info(f"Starting agent '{self.name}' (type={self.agent_type}, version={self.version})")

        # Needed in both modes; fetch once, before touching any resources.
        capability = self.get_capabilities()

        try:
            self._redis = aioredis.from_url(
                settings.REDIS_URL,
                decode_responses=True,
            )
            await self._redis.ping()

            registry = _get_registry()
            await registry.register(capability, endpoint=f"agent:{self.name}")
        except Exception as e:
            # Release the partially initialised client instead of leaking it.
            await self._close_redis()
            self._status = AgentStatus.ONLINE
            self._init_semaphore(capability)
            logger.warning(f"Agent '{self.name}' started in local mode (Redis unavailable: {e})")
            return

        self._status = AgentStatus.ONLINE
        max_concurrency = self._init_semaphore(capability)
        logger.info(
            f"Agent '{self.name}' concurrency limit set to {max_concurrency}"
        )

        # Status must already be ONLINE here: both loops exit once it is neither
        # ONLINE nor BUSY.
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
        self._listen_task = asyncio.create_task(self._listen_for_tasks())

        logger.info(f"Agent '{self.name}' started in distributed mode")

    @staticmethod
    async def _cancel_background_task(task: "asyncio.Task | None") -> None:
        """Cancel *task* if it is still pending and wait for it to finish."""
        if task is None or task.done():
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def stop(self):
        """Stop the agent: end background loops, unregister, close Redis.

        The agent is marked OFFLINE first. Execution tasks that are already
        running are not cancelled or awaited here. A failing ``unregister`` is
        logged and does not prevent the Redis connection from being closed.
        """
        logger.info(f"Stopping agent '{self.name}'")

        self._status = AgentStatus.OFFLINE

        await self._cancel_background_task(self._listen_task)
        await self._cancel_background_task(self._heartbeat_task)

        if self._redis is not None:
            try:
                registry = _get_registry()
                await registry.unregister(self.name)
            except Exception as e:
                logger.warning(f"Failed to unregister agent '{self.name}': {e}")
            await self._close_redis()

        logger.info(f"Agent '{self.name}' stopped")

    async def heartbeat(self):
        """Report one heartbeat for this agent to the registry."""
        registry = _get_registry()
        await registry.update_heartbeat(self.name)

    async def report_progress(self, task_id: str, progress: float, message: str):
        """Report progress for *task_id*; failures are logged, never raised.

        In distributed mode the update is published on the Redis channel
        ``agent:{name}:progress``; it is then forwarded to the dispatcher.
        """
        progress_obj = TaskProgress(
            task_id=task_id,
            agent_name=self.name,
            progress=progress,
            message=message,
            updated_at=datetime.now(timezone.utc),
        )

        if self._redis:
            try:
                await self._redis.publish(
                    f"agent:{self.name}:progress",
                    json.dumps(progress_obj.to_dict()),
                )
            except Exception as e:
                logger.warning(f"Failed to publish progress for task {task_id}: {e}")

        # NOTE(review): unlike result reporting, progress is forwarded to the
        # dispatcher even in local mode — confirm this is intended.
        try:
            dispatcher = _get_dispatcher()
            await dispatcher.handle_progress(progress_obj)
        except Exception as e:
            logger.warning(f"Failed to report progress to dispatcher for task {task_id}: {e}")

    async def _heartbeat_loop(self):
        """Send a heartbeat every HEARTBEAT_INTERVAL_SECONDS while running.

        A failed heartbeat is logged and retried on the next tick rather than
        ending the loop; cancellation (from stop()) propagates.
        """
        while self._is_running():
            try:
                await self.heartbeat()
            except Exception as e:
                logger.error(f"Heartbeat error for agent '{self.name}': {e}")
            await asyncio.sleep(self.HEARTBEAT_INTERVAL_SECONDS)

    async def _listen_for_tasks(self):
        """Pop tasks from the Redis list ``agent:{name}:tasks`` and run them.

        Each task runs in its own asyncio task, bounded by the semaphore.
        Redis errors are logged and retried after LISTENER_RETRY_DELAY_SECONDS
        instead of killing the listener; unparsable messages are logged and
        dropped. Cancellation (from stop()) propagates.
        """
        queue_key = f"agent:{self.name}:tasks"
        while self._is_running():
            client = self._redis
            if client is None:
                await asyncio.sleep(self.LISTENER_RETRY_DELAY_SECONDS)
                continue

            try:
                # Short blocking timeout so the status is re-checked regularly.
                popped = await client.brpop(queue_key, timeout=self.TASK_POLL_TIMEOUT_SECONDS)
            except Exception as e:
                logger.error(f"Task listener error for agent '{self.name}': {e}")
                await asyncio.sleep(self.LISTENER_RETRY_DELAY_SECONDS)
                continue

            if not popped:
                continue  # timed out, nothing queued

            _, task_json = popped
            try:
                task = TaskMessage.from_dict(json.loads(task_json))
            except Exception as e:
                logger.error(f"Failed to parse task message: {e}")
                continue
            self._spawn_task(task)

    def _spawn_task(self, task: TaskMessage) -> None:
        """Run *task* in the background, holding a reference until it finishes."""
        handle = asyncio.create_task(self._execute_task_with_semaphore(task))
        self._inflight_tasks.add(handle)
        handle.add_done_callback(self._inflight_tasks.discard)

    async def _execute_task_with_semaphore(self, task: TaskMessage):
        """Execute *task*, bounded by the concurrency semaphore when one exists."""
        if self._semaphore is None:
            await self._execute_task(task)
            return
        async with self._semaphore:
            await self._execute_task(task)

    async def _publish_result(self, result: TaskResult) -> None:
        """Hand *result* to the dispatcher in distributed mode; log on failure."""
        if self._redis is None:
            return
        try:
            dispatcher = _get_dispatcher()
            await dispatcher.handle_result(result)
        except Exception as e:
            logger.error(f"Failed to report result for task {result.task_id} to dispatcher: {e}")

    async def _execute_task(self, task: TaskMessage):
        """Run ``execute()`` for *task* and report the outcome.

        A successful result gets ``started_at`` / ``completed_at`` stamped; an
        exception from ``execute()`` becomes a FAILED TaskResult. Reporting
        problems are logged separately, so a dispatcher error can no longer turn
        a successful task into a FAILED one. Status goes ONLINE -> BUSY while any
        task runs and back to ONLINE when the last one ends, but an agent that
        was stopped (OFFLINE) meanwhile is never flipped back.
        """
        self._running_tasks.add(task.task_id)
        if self._status == AgentStatus.ONLINE:
            self._status = AgentStatus.BUSY

        started_at = datetime.now(timezone.utc)
        try:
            logger.info(f"Agent '{self.name}' executing task {task.task_id} (type={task.task_type})")
            try:
                result = await self.execute(task)
                result.started_at = started_at
                result.completed_at = datetime.now(timezone.utc)
            except Exception as e:
                logger.exception(f"Agent '{self.name}' task {task.task_id} failed: {e}")
                result = TaskResult(
                    task_id=task.task_id,
                    agent_name=self.name,
                    status=TaskStatus.FAILED,
                    output_data=None,
                    error_message=str(e),
                    started_at=started_at,
                    completed_at=datetime.now(timezone.utc),
                    metrics=None,
                )
            await self._publish_result(result)
        finally:
            self._running_tasks.discard(task.task_id)
            if not self._running_tasks and self._status == AgentStatus.BUSY:
                self._status = AgentStatus.ONLINE
|