219 lines
8.4 KiB
Python
219 lines
8.4 KiB
Python
"""Agent 注册中心 - 管理 Agent 的注册、发现、状态"""
|
||
|
||
import logging
|
||
import uuid
|
||
from datetime import datetime, timedelta, timezone
|
||
|
||
from sqlalchemy import select, update
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.agent_framework.exceptions import (
|
||
AgentAlreadyRegisteredError,
|
||
AgentNotFoundError,
|
||
AgentUnavailableError,
|
||
NoAvailableAgentError,
|
||
)
|
||
from app.agent_framework.protocol import AgentCapability, AgentStatus
|
||
from app.database import AsyncSessionLocal
|
||
from app.models.agent import AgentRegistry as AgentRegistryModel
|
||
|
||
# Logger named after this module, so registry messages can be filtered/configured separately.
logger = logging.getLogger(name=__name__)

# Heartbeat timeout threshold, in seconds: an ONLINE agent whose last heartbeat
# is older than this is marked OFFLINE by the registry's health check.
HEARTBEAT_TIMEOUT_SECONDS: int = 90
|
||
|
||
|
||
class AgentRegistry:
    """Agent registry: manages registration, discovery and status of agents.

    All state lives in the table mapped by ``AgentRegistryModel``. Every public
    method opens its own short-lived session via ``AsyncSessionLocal``, so an
    instance of this class carries no state of its own.
    """

    async def register(self, capability: AgentCapability, endpoint: str) -> str:
        """Register an agent and return its id.

        If an agent with the same name (``capability.agent_name``) already
        exists, its record is updated (type, version, endpoint, description,
        capabilities) and marked ONLINE; otherwise a new record is created.

        Args:
            capability: Capability descriptor of the agent; its ``agent_name``
                is the lookup key.
            endpoint: Endpoint at which the agent can be reached.

        Returns:
            The agent's primary key converted with ``str()``.

        Raises:
            Whatever the database layer raises; the session is rolled back
            before re-raising.
        """
        async with AsyncSessionLocal() as db:
            try:
                # Look up an existing record by name.
                stmt = select(AgentRegistryModel).where(
                    AgentRegistryModel.name == capability.agent_name
                )
                result = await db.execute(stmt)
                existing = result.scalar_one_or_none()

                # NOTE(review): this is check-then-insert. Two concurrent
                # first-time registrations of the same name can both take the
                # insert branch; whether that fails or duplicates rows depends
                # on constraints in the table schema, which is not visible here.
                if existing:
                    # Refresh the existing record (display_name is kept as-is).
                    existing.agent_type = capability.agent_type
                    existing.version = capability.version
                    existing.endpoint = endpoint
                    existing.description = capability.description
                    existing.capabilities = capability.to_dict()
                    existing.status = AgentStatus.ONLINE
                    existing.last_heartbeat = datetime.now(timezone.utc)
                    await db.commit()
                    await db.refresh(existing)
                    agent_id = existing.id
                    logger.info(
                        "Agent '%s' re-registered (id=%s)",
                        capability.agent_name,
                        agent_id,
                    )
                else:
                    # Create a new record; display_name is derived from the name,
                    # e.g. "code_review" -> "Code Review".
                    agent = AgentRegistryModel(
                        name=capability.agent_name,
                        display_name=capability.agent_name.replace("_", " ").title(),
                        agent_type=capability.agent_type,
                        description=capability.description,
                        version=capability.version,
                        endpoint=endpoint,
                        status=AgentStatus.ONLINE,
                        capabilities=capability.to_dict(),
                        last_heartbeat=datetime.now(timezone.utc),
                    )
                    db.add(agent)
                    await db.commit()
                    await db.refresh(agent)
                    agent_id = agent.id
                    logger.info(
                        "Agent '%s' registered (id=%s)",
                        capability.agent_name,
                        agent_id,
                    )

                return str(agent_id)

            except Exception as e:
                await db.rollback()
                # Re-raised below, so the caller owns the traceback; log the summary only.
                logger.error(
                    "Failed to register agent '%s': %s", capability.agent_name, e
                )
                raise

    async def unregister(self, agent_name: str) -> None:
        """Unregister an agent by setting its status to OFFLINE.

        The record itself is kept. Unregistering an unknown name only logs a
        warning.

        Raises:
            Whatever the database layer raises; the session is rolled back
            before re-raising.
        """
        async with AsyncSessionLocal() as db:
            try:
                stmt = select(AgentRegistryModel).where(
                    AgentRegistryModel.name == agent_name
                )
                result = await db.execute(stmt)
                agent = result.scalar_one_or_none()

                if not agent:
                    logger.warning(
                        "Attempted to unregister non-existent agent '%s'", agent_name
                    )
                    return

                agent.status = AgentStatus.OFFLINE
                await db.commit()
                logger.info("Agent '%s' unregistered", agent_name)

            except Exception as e:
                await db.rollback()
                # Re-raised below, so the caller owns the traceback.
                logger.error("Failed to unregister agent '%s': %s", agent_name, e)
                raise

    async def update_heartbeat(self, agent_name: str) -> None:
        """Record a heartbeat: set ``last_heartbeat`` to now (UTC) and status to ONLINE.

        Because status is forced to ONLINE, a heartbeat from an agent that was
        unregistered or timed out brings it back online.

        This is best-effort: database errors are rolled back and logged, not
        raised. A heartbeat for a name with no record is logged as a warning.
        """
        async with AsyncSessionLocal() as db:
            try:
                stmt = (
                    update(AgentRegistryModel)
                    .where(AgentRegistryModel.name == agent_name)
                    .values(
                        last_heartbeat=datetime.now(timezone.utc),
                        status=AgentStatus.ONLINE,
                    )
                )
                result = await db.execute(stmt)
                await db.commit()

                # An UPDATE matching no rows is otherwise silent; surface it so a
                # misconfigured or never-registered agent is noticed.
                if result.rowcount == 0:
                    logger.warning(
                        "Heartbeat received for unknown agent '%s'", agent_name
                    )

            except Exception as e:
                await db.rollback()
                # The error is swallowed, so keep the traceback in the log.
                logger.exception(
                    "Failed to update heartbeat for agent '%s': %s", agent_name, e
                )

    async def get_agent(self, agent_name: str) -> dict | None:
        """Return the agent's record as a dict (see ``_agent_to_dict``), or None if absent."""
        async with AsyncSessionLocal() as db:
            stmt = select(AgentRegistryModel).where(
                AgentRegistryModel.name == agent_name
            )
            result = await db.execute(stmt)
            agent = result.scalar_one_or_none()

            if not agent:
                return None

            return self._agent_to_dict(agent)

    async def list_agents(
        self,
        agent_type: str | None = None,
        status: str | None = None,
    ) -> list[dict]:
        """List agents, newest first, optionally filtered by type and/or status.

        Args:
            agent_type: If truthy, only agents with this ``agent_type``.
            status: If truthy, only agents with this ``status``.

        Returns:
            Dicts as produced by ``_agent_to_dict``, ordered by ``created_at``
            descending.
        """
        async with AsyncSessionLocal() as db:
            stmt = select(AgentRegistryModel)
            if agent_type:
                stmt = stmt.where(AgentRegistryModel.agent_type == agent_type)
            if status:
                stmt = stmt.where(AgentRegistryModel.status == status)
            stmt = stmt.order_by(AgentRegistryModel.created_at.desc())

            result = await db.execute(stmt)
            agents = result.scalars().all()

            return [self._agent_to_dict(a) for a in agents]

    async def get_available_agent(self, task_type: str) -> str | None:
        """Find an ONLINE agent whose capabilities list ``task_type``; return its name.

        Candidates are scanned in name order, so the choice is deterministic
        for a given registry state. Only the status column is consulted:
        heartbeat freshness is not checked here, so an agent with a lapsed
        heartbeat stays eligible until ``check_health`` marks it OFFLINE.

        Returns:
            The first matching agent's name, or None if no ONLINE agent
            supports ``task_type``.
        """
        async with AsyncSessionLocal() as db:
            stmt = (
                select(AgentRegistryModel)
                .where(AgentRegistryModel.status == AgentStatus.ONLINE)
                # Without an ORDER BY the "first match" depends on whatever
                # order the database returns rows in.
                .order_by(AgentRegistryModel.name)
            )
            result = await db.execute(stmt)
            agents = result.scalars().all()

            for agent in agents:
                capabilities = agent.capabilities or {}
                # `or []` also covers an explicit null stored for the key,
                # which `.get(key, [])` would pass through (-> TypeError on `in`).
                supported_tasks = capabilities.get("supported_tasks") or []
                if task_type in supported_tasks:
                    return agent.name

            return None

    async def check_health(self) -> None:
        """Mark ONLINE agents whose heartbeat timed out as OFFLINE.

        An agent times out when ``last_heartbeat`` is older than
        ``HEARTBEAT_TIMEOUT_SECONDS`` before now (UTC). Rows with a NULL
        ``last_heartbeat`` do not satisfy the SQL ``<`` comparison and are left
        untouched.

        This is best-effort: database errors are rolled back and logged, not
        raised.
        """
        async with AsyncSessionLocal() as db:
            try:
                timeout_threshold = datetime.now(timezone.utc) - timedelta(
                    seconds=HEARTBEAT_TIMEOUT_SECONDS
                )

                # Single bulk UPDATE for all timed-out agents.
                stmt = (
                    update(AgentRegistryModel)
                    .where(
                        AgentRegistryModel.status == AgentStatus.ONLINE,
                        AgentRegistryModel.last_heartbeat < timeout_threshold,
                    )
                    .values(status=AgentStatus.OFFLINE)
                )
                result = await db.execute(stmt)
                await db.commit()

                if result.rowcount > 0:
                    logger.warning(
                        "Marked %s agent(s) as offline due to heartbeat timeout",
                        result.rowcount,
                    )

            except Exception as e:
                await db.rollback()
                # The error is swallowed, so keep the traceback in the log.
                logger.exception("Failed to check agent health: %s", e)

    def _agent_to_dict(self, agent: AgentRegistryModel) -> dict:
        """Convert an agent ORM object to a plain dict.

        ``id`` is stringified; datetime fields are rendered with
        ``isoformat()``, or None when unset. ``status`` and ``capabilities``
        are passed through as stored.
        """
        return {
            "id": str(agent.id),
            "name": agent.name,
            "display_name": agent.display_name,
            "agent_type": agent.agent_type,
            "description": agent.description,
            "version": agent.version,
            "endpoint": agent.endpoint,
            "status": agent.status,
            "capabilities": agent.capabilities,
            "last_heartbeat": agent.last_heartbeat.isoformat() if agent.last_heartbeat else None,
            "created_at": agent.created_at.isoformat() if agent.created_at else None,
            "updated_at": agent.updated_at.isoformat() if agent.updated_at else None,
        }
|