geo/backend/app/agent_framework/base.py

256 lines
8.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""BaseAgent base class - defines the standard lifecycle shared by all agents."""
import asyncio
import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime, timezone
import redis.asyncio as aioredis
from app.agent_framework.exceptions import AgentNotReadyError
from app.agent_framework.protocol import (
AgentCapability,
AgentStatus,
TaskMessage,
TaskProgress,
TaskResult,
TaskStatus,
)
from app.config import settings
logger = logging.getLogger(__name__)
# Process-wide TaskDispatcher cache. The dispatcher module is imported inside
# the accessor (not at module load) to avoid a circular import, and the
# instance is kept here so repeated calls reuse one dispatcher.
_dispatcher_instance = None


def _get_dispatcher():
    """Return the shared TaskDispatcher, building it on first use."""
    global _dispatcher_instance
    if _dispatcher_instance is not None:
        return _dispatcher_instance

    # Deferred import: see the note on the cache variable above.
    from app.agent_framework.dispatcher import TaskDispatcher

    _dispatcher_instance = TaskDispatcher(settings.REDIS_URL)
    return _dispatcher_instance
# Process-wide AgentRegistry cache. The registry module is imported inside the
# accessor (not at module load) to avoid a circular import, and the instance
# is kept here so repeated calls reuse one registry.
_registry_instance = None


def _get_registry():
    """Return the shared AgentRegistry, building it on first use."""
    global _registry_instance
    if _registry_instance is not None:
        return _registry_instance

    # Deferred import: see the note on the cache variable above.
    from app.agent_framework.registry import AgentRegistry

    _registry_instance = AgentRegistry()
    return _registry_instance
class BaseAgent(ABC):
    """Base class for every agent; defines the standard agent lifecycle.

    Subclasses implement :meth:`execute` and :meth:`get_capabilities`.

    :meth:`start` tries to connect to Redis. On success the agent registers
    itself with the registry, emits periodic heartbeats and consumes tasks
    from the Redis list ``agent:<name>:tasks`` ("distributed mode"). If Redis
    or registration fails, the agent still goes ONLINE but without Redis and
    without background loops ("local mode").
    """

    # Seconds between two heartbeat reports to the registry.
    HEARTBEAT_INTERVAL_SECONDS = 30
    # Pause after a failed queue poll so a broken connection is not hammered.
    ERROR_BACKOFF_SECONDS = 1

    def __init__(self, name: str, agent_type: str, version: str = "1.0.0"):
        """Create an OFFLINE agent; no I/O happens until :meth:`start`."""
        self.name = name
        self.agent_type = agent_type
        self.version = version
        self._status: AgentStatus = AgentStatus.OFFLINE
        self._redis: aioredis.Redis | None = None
        # IDs of the tasks currently inside _execute_task.
        self._running_tasks: set[str] = set()
        self._listen_task: asyncio.Task | None = None
        self._heartbeat_task: asyncio.Task | None = None
        self._semaphore: asyncio.Semaphore | None = None
        # Strong references to spawned execution tasks. The event loop only
        # keeps weak references, so an unreferenced task may be garbage
        # collected before it finishes.
        self._execution_tasks: set[asyncio.Task] = set()

    @property
    def status(self) -> AgentStatus:
        """Current lifecycle status of the agent."""
        return self._status

    @property
    def is_distributed(self) -> bool:
        """True while the agent holds a Redis client (distributed mode)."""
        return self._redis is not None

    @abstractmethod
    async def execute(self, task: TaskMessage) -> TaskResult:
        """Run the core task logic; must be implemented by subclasses."""
        ...

    @abstractmethod
    def get_capabilities(self) -> AgentCapability:
        """Return this agent's capability declaration."""
        ...

    async def start(self):
        """Bring the agent ONLINE, in distributed mode when Redis is reachable.

        The concurrency limit is read from the ``max_concurrency`` attribute
        of the capability declaration (defaulting to 1). Any failure while
        connecting to Redis or registering falls back to local mode; the
        half-open Redis client is closed instead of being dropped.
        """
        logger.info(f"Starting agent '{self.name}' (type={self.agent_type}, version={self.version})")
        # Query the capability declaration once and reuse it for both the
        # registry and the concurrency limit.
        capability = self.get_capabilities()
        max_concurrency = getattr(capability, 'max_concurrency', 1) or 1
        self._semaphore = asyncio.Semaphore(max_concurrency)
        try:
            self._redis = aioredis.from_url(
                settings.REDIS_URL,
                decode_responses=True,
            )
            await self._redis.ping()
            registry = _get_registry()
            await registry.register(capability, endpoint=f"agent:{self.name}")
        except Exception as e:
            # Release the client before forgetting it so no connection leaks.
            await self._close_redis()
            self._status = AgentStatus.ONLINE
            logger.warning(f"Agent '{self.name}' started in local mode (Redis unavailable: {e})")
            return

        self._status = AgentStatus.ONLINE
        logger.info(
            f"Agent '{self.name}' concurrency limit set to {max_concurrency}"
        )
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
        self._listen_task = asyncio.create_task(self._listen_for_tasks())
        logger.info(f"Agent '{self.name}' started in distributed mode")

    async def stop(self):
        """Take the agent OFFLINE and release its background resources.

        Cancels the listener and heartbeat loops, unregisters from the
        registry and closes Redis. Executions that are already in flight are
        not cancelled; they finish without reporting because Redis is gone.
        An error raised by ``unregister`` still propagates, but only after
        the Redis client has been closed.
        """
        logger.info(f"Stopping agent '{self.name}'")
        self._status = AgentStatus.OFFLINE
        await self._cancel_background(self._listen_task)
        await self._cancel_background(self._heartbeat_task)
        if self._redis is not None:
            try:
                registry = _get_registry()
                await registry.unregister(self.name)
            finally:
                await self._close_redis()
        logger.info(f"Agent '{self.name}' stopped")

    async def heartbeat(self):
        """Report a single heartbeat for this agent to the registry."""
        registry = _get_registry()
        await registry.update_heartbeat(self.name)

    async def report_progress(self, task_id: str, progress: float, message: str):
        """Publish a progress update for ``task_id`` (best effort).

        When Redis is available the update is published on the channel
        ``agent:<name>:progress``; it is then always handed to the
        dispatcher. Failures of either step are logged and never raised.
        """
        progress_obj = TaskProgress(
            task_id=task_id,
            agent_name=self.name,
            progress=progress,
            message=message,
            updated_at=datetime.now(timezone.utc),
        )
        if self._redis:
            try:
                await self._redis.publish(
                    f"agent:{self.name}:progress",
                    json.dumps(progress_obj.to_dict()),
                )
            except Exception as e:
                logger.warning(f"Failed to publish progress for task {task_id}: {e}")
        try:
            dispatcher = _get_dispatcher()
            await dispatcher.handle_progress(progress_obj)
        except Exception as e:
            logger.warning(f"Failed to report progress to dispatcher for task {task_id}: {e}")

    def _is_active(self) -> bool:
        """Return True while the agent is started, i.e. ONLINE or BUSY."""
        return self._status in (AgentStatus.ONLINE, AgentStatus.BUSY)

    @staticmethod
    async def _cancel_background(task):
        """Cancel a background loop task (if still running) and wait for it."""
        if task is None or task.done():
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    async def _close_redis(self):
        """Close and forget the Redis client, if any; close errors are logged."""
        client, self._redis = self._redis, None
        if client is None:
            return
        # Newer redis-py releases deprecate close() in favour of aclose();
        # use whichever the installed client offers.
        close = getattr(client, "aclose", None) or client.close
        try:
            await close()
        except Exception as e:
            logger.warning(f"Failed to close Redis connection for agent '{self.name}': {e}")

    async def _heartbeat_loop(self):
        """Send heartbeats until the agent is stopped.

        The loop keeps running while the agent is BUSY, and one failed
        heartbeat is logged and retried on the next interval instead of
        ending the loop.
        """
        try:
            while self._is_active():
                try:
                    await self.heartbeat()
                except Exception as e:
                    logger.error(f"Heartbeat error for agent '{self.name}': {e}")
                await asyncio.sleep(self.HEARTBEAT_INTERVAL_SECONDS)
        except asyncio.CancelledError:
            pass

    async def _listen_for_tasks(self):
        """Consume task messages from the Redis queue until the agent stops.

        The loop keeps polling while the agent is BUSY so that further tasks
        can run up to the concurrency limit. Queue errors and malformed
        messages are logged and skipped rather than ending the listener.
        """
        queue_key = f"agent:{self.name}:tasks"
        try:
            while self._is_active():
                if not self._redis:
                    await asyncio.sleep(1)
                    continue
                try:
                    # Blocking pop with a 1 s timeout so the loop condition
                    # is re-checked regularly.
                    result = await self._redis.brpop(queue_key, timeout=1)
                except Exception as e:
                    logger.error(f"Task listener error for agent '{self.name}': {e}")
                    await asyncio.sleep(self.ERROR_BACKOFF_SECONDS)
                    continue
                if not result:
                    continue
                _, task_json = result
                try:
                    task = TaskMessage.from_dict(json.loads(task_json))
                except Exception as e:
                    logger.error(f"Failed to parse task message: {e}")
                    continue
                self._spawn_execution(task)
        except asyncio.CancelledError:
            pass

    def _spawn_execution(self, task: TaskMessage):
        """Start executing ``task`` in the background, keeping a reference."""
        runner = asyncio.create_task(self._execute_task_with_semaphore(task))
        self._execution_tasks.add(runner)
        runner.add_done_callback(self._execution_tasks.discard)

    async def _execute_task_with_semaphore(self, task: TaskMessage):
        """Execute ``task`` under the concurrency semaphore, if one is set."""
        if self._semaphore is None:
            await self._execute_task(task)
            return
        async with self._semaphore:
            await self._execute_task(task)

    async def _execute_task(self, task: TaskMessage):
        """Run ``task`` via :meth:`execute` and report the outcome.

        The result is stamped with start/completion times and, in
        distributed mode, handed to the dispatcher. Any exception becomes a
        FAILED result. The status moves ONLINE -> BUSY for the duration and
        back once no task is running; an OFFLINE agent is left OFFLINE.
        """
        self._running_tasks.add(task.task_id)
        if self._status == AgentStatus.ONLINE:
            self._status = AgentStatus.BUSY
        started_at = datetime.now(timezone.utc)
        try:
            logger.info(f"Agent '{self.name}' executing task {task.task_id} (type={task.task_type})")
            result = await self.execute(task)
            result.started_at = started_at
            result.completed_at = datetime.now(timezone.utc)
            if self._redis is not None:
                dispatcher = _get_dispatcher()
                await dispatcher.handle_result(result)
        except Exception as e:
            logger.error(f"Agent '{self.name}' task {task.task_id} failed: {e}")
            error_result = TaskResult(
                task_id=task.task_id,
                agent_name=self.name,
                status=TaskStatus.FAILED,
                output_data=None,
                error_message=str(e),
                started_at=started_at,
                completed_at=datetime.now(timezone.utc),
                metrics=None,
            )
            if self._redis is not None:
                try:
                    dispatcher = _get_dispatcher()
                    await dispatcher.handle_result(error_result)
                except Exception as report_error:
                    # Nothing left to notify; log instead of letting the
                    # background task die with an unretrieved exception.
                    logger.error(
                        f"Agent '{self.name}' could not report failure of task {task.task_id}: {report_error}"
                    )
        finally:
            self._running_tasks.discard(task.task_id)
            # Only a BUSY agent goes back to ONLINE; a stopped agent must
            # not be revived by a late-finishing task.
            if not self._running_tasks and self._status == AgentStatus.BUSY:
                self._status = AgentStatus.ONLINE