geo/backend/app/agent_framework/registry.py

219 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Agent 注册中心 - 管理 Agent 的注册、发现、状态"""
import logging
import uuid
from datetime import datetime, timedelta, timezone
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.agent_framework.exceptions import (
AgentAlreadyRegisteredError,
AgentNotFoundError,
AgentUnavailableError,
NoAvailableAgentError,
)
from app.agent_framework.protocol import AgentCapability, AgentStatus
from app.database import AsyncSessionLocal
from app.models.agent import AgentRegistry as AgentRegistryModel
logger = logging.getLogger(__name__)
# Heartbeat timeout threshold, in seconds. AgentRegistry.check_health() marks
# an ONLINE agent as offline once its last heartbeat is older than this.
HEARTBEAT_TIMEOUT_SECONDS = 90
class AgentRegistry:
    """Agent registry: handles agent registration, discovery and status.

    Every public method opens its own ``AsyncSessionLocal`` session for the
    duration of the call; the instance itself keeps no state.
    """

    async def register(self, capability: AgentCapability, endpoint: str) -> str:
        """Register an agent and return its id as a string.

        If a row whose name equals ``capability.agent_name`` already exists,
        its type, version, endpoint, description and capabilities are
        refreshed and it is marked online. Otherwise a new row is inserted.

        Args:
            capability: Capability descriptor of the agent being registered.
            endpoint: Endpoint value stored on the agent's row.

        Returns:
            ``str()`` of the row's ``id``.

        Raises:
            Exception: Any error raised while talking to the database. The
                session is rolled back and the error is logged (with
                traceback) before it is re-raised.
        """
        agent_name = capability.agent_name
        async with AsyncSessionLocal() as db:
            try:
                # NOTE(review): look-up followed by insert is not atomic; this
                # presumably relies on a unique constraint on ``name`` to
                # reject concurrent duplicate registrations - confirm.
                lookup = select(AgentRegistryModel).where(
                    AgentRegistryModel.name == agent_name
                )
                record = (await db.execute(lookup)).scalar_one_or_none()
                now = datetime.now(timezone.utc)

                if record is not None:
                    # Refresh the existing row; name and display_name are
                    # deliberately left as they were.
                    record.agent_type = capability.agent_type
                    record.version = capability.version
                    record.endpoint = endpoint
                    record.description = capability.description
                    record.capabilities = capability.to_dict()
                    record.status = AgentStatus.ONLINE
                    record.last_heartbeat = now
                    action = "re-registered"
                else:
                    record = AgentRegistryModel(
                        name=agent_name,
                        display_name=agent_name.replace("_", " ").title(),
                        agent_type=capability.agent_type,
                        description=capability.description,
                        version=capability.version,
                        endpoint=endpoint,
                        status=AgentStatus.ONLINE,
                        capabilities=capability.to_dict(),
                        last_heartbeat=now,
                    )
                    db.add(record)
                    action = "registered"

                await db.commit()
                # Reload the row after the commit before reading its id.
                await db.refresh(record)
                agent_id = record.id
                logger.info("Agent '%s' %s (id=%s)", agent_name, action, agent_id)
                return str(agent_id)
            except Exception as exc:
                await db.rollback()
                logger.exception(
                    "Failed to register agent '%s': %s", agent_name, exc
                )
                raise

    async def unregister(self, agent_name: str):
        """Mark the named agent as offline.

        The row is kept; only its status changes. An unknown name is logged
        as a warning and otherwise ignored.

        Raises:
            Exception: Any database error, after rollback and logging.
        """
        async with AsyncSessionLocal() as db:
            try:
                lookup = select(AgentRegistryModel).where(
                    AgentRegistryModel.name == agent_name
                )
                record = (await db.execute(lookup)).scalar_one_or_none()
                if not record:
                    logger.warning(
                        "Attempted to unregister non-existent agent '%s'",
                        agent_name,
                    )
                    return
                record.status = AgentStatus.OFFLINE
                await db.commit()
                logger.info("Agent '%s' unregistered", agent_name)
            except Exception as exc:
                await db.rollback()
                logger.exception(
                    "Failed to unregister agent '%s': %s", agent_name, exc
                )
                raise

    async def update_heartbeat(self, agent_name: str):
        """Record a heartbeat: set ``last_heartbeat`` to now and status online.

        This is best-effort: database errors are rolled back and logged but
        not re-raised. A heartbeat for a name that matches no row is reported
        with a warning instead of being silently dropped.
        """
        async with AsyncSessionLocal() as db:
            try:
                stmt = (
                    update(AgentRegistryModel)
                    .where(AgentRegistryModel.name == agent_name)
                    .values(
                        last_heartbeat=datetime.now(timezone.utc),
                        status=AgentStatus.ONLINE,
                    )
                )
                result = await db.execute(stmt)
                await db.commit()
                if result.rowcount == 0:
                    # The UPDATE matched nothing, i.e. no agent has this name.
                    logger.warning(
                        "Heartbeat received for unknown agent '%s'", agent_name
                    )
            except Exception as exc:
                await db.rollback()
                logger.exception(
                    "Failed to update heartbeat for agent '%s': %s",
                    agent_name,
                    exc,
                )

    async def get_agent(self, agent_name: str) -> dict | None:
        """Return the named agent as a dict, or ``None`` if no row matches."""
        async with AsyncSessionLocal() as db:
            lookup = select(AgentRegistryModel).where(
                AgentRegistryModel.name == agent_name
            )
            record = (await db.execute(lookup)).scalar_one_or_none()
            if not record:
                return None
            return self._agent_to_dict(record)

    async def list_agents(
        self,
        agent_type: str | None = None,
        status: str | None = None,
    ) -> list[dict]:
        """List agents as dicts, newest ``created_at`` first.

        Args:
            agent_type: When truthy, only agents of this type are returned.
            status: When truthy, only agents with this status are returned.
        """
        async with AsyncSessionLocal() as db:
            stmt = select(AgentRegistryModel)
            if agent_type:
                stmt = stmt.where(AgentRegistryModel.agent_type == agent_type)
            if status:
                stmt = stmt.where(AgentRegistryModel.status == status)
            stmt = stmt.order_by(AgentRegistryModel.created_at.desc())
            rows = (await db.execute(stmt)).scalars().all()
            return [self._agent_to_dict(row) for row in rows]

    async def get_available_agent(self, task_type: str) -> str | None:
        """Return the name of an online agent that supports ``task_type``.

        Only rows whose status is ONLINE are considered. The first one whose
        capabilities list ``task_type`` under ``"supported_tasks"`` wins; the
        query has no ORDER BY, so which agent is picked when several match is
        left to the database. Returns ``None`` when nothing matches.
        """
        async with AsyncSessionLocal() as db:
            stmt = select(AgentRegistryModel).where(
                AgentRegistryModel.status == AgentStatus.ONLINE,
            )
            online_agents = (await db.execute(stmt)).scalars().all()
            for candidate in online_agents:
                if self._supports_task(candidate.capabilities, task_type):
                    return candidate.name
            return None

    @staticmethod
    def _supports_task(capabilities, task_type: str) -> bool:
        """Tell whether a stored capabilities value lists ``task_type``.

        Tolerates a missing/``None``/non-dict capabilities value and a
        ``"supported_tasks"`` entry that is absent or ``None``; all of these
        count as "no supported tasks" instead of raising.
        """
        if not isinstance(capabilities, dict):
            return False
        # ``dict.get`` only applies its default when the key is missing, so a
        # stored null must be normalised explicitly.
        supported = capabilities.get("supported_tasks") or ()
        return task_type in supported

    async def check_health(self):
        """Mark ONLINE agents with a stale heartbeat as offline.

        An agent is stale when its ``last_heartbeat`` is more than
        ``HEARTBEAT_TIMEOUT_SECONDS`` in the past. This is best-effort:
        database errors are rolled back and logged but not re-raised.
        """
        async with AsyncSessionLocal() as db:
            try:
                cutoff = datetime.now(timezone.utc) - timedelta(
                    seconds=HEARTBEAT_TIMEOUT_SECONDS
                )
                # NOTE(review): a NULL last_heartbeat does not satisfy a SQL
                # "<" comparison, so such rows would never be timed out here -
                # confirm the column is always populated.
                stmt = (
                    update(AgentRegistryModel)
                    .where(
                        AgentRegistryModel.status == AgentStatus.ONLINE,
                        AgentRegistryModel.last_heartbeat < cutoff,
                    )
                    .values(status=AgentStatus.OFFLINE)
                )
                result = await db.execute(stmt)
                await db.commit()
                if result.rowcount > 0:
                    logger.warning(
                        "Marked %s agent(s) as offline due to heartbeat timeout",
                        result.rowcount,
                    )
            except Exception as exc:
                await db.rollback()
                logger.exception("Failed to check agent health: %s", exc)

    def _agent_to_dict(self, agent: AgentRegistryModel) -> dict:
        """Convert an agent ORM row into a plain dict.

        ``id`` is stringified and the three timestamp fields are rendered
        with ``isoformat()`` (``None`` when unset); every other field is
        copied through unchanged.
        """

        def iso_or_none(moment):
            # Timestamps may be unset on the row; keep those as None.
            return moment.isoformat() if moment else None

        return {
            "id": str(agent.id),
            "name": agent.name,
            "display_name": agent.display_name,
            "agent_type": agent.agent_type,
            "description": agent.description,
            "version": agent.version,
            "endpoint": agent.endpoint,
            "status": agent.status,
            "capabilities": agent.capabilities,
            "last_heartbeat": iso_or_none(agent.last_heartbeat),
            "created_at": iso_or_none(agent.created_at),
            "updated_at": iso_or_none(agent.updated_at),
        }