191 lines
6.7 KiB
Python
191 lines
6.7 KiB
Python
"""Agent 配置管理 - 支持热更新"""
|
|
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.agent_framework.exceptions import AgentNotFoundError, ConfigValidationError
|
|
from app.database import AsyncSessionLocal
|
|
from app.models.agent import AgentConfig as AgentConfigModel
|
|
from app.models.agent import AgentRegistry as AgentRegistryModel
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AgentConfigManager:
    """Manage per-agent configuration entries stored in the database.

    Every public method opens its own short-lived session via
    ``AsyncSessionLocal`` and the class keeps no in-memory cache, so each read
    queries the database afresh and sees values committed by earlier calls.
    """

    async def get_config(self, agent_name: str) -> dict:
        """Return every configuration entry of an agent.

        Args:
            agent_name: Registry name of the agent.

        Returns:
            ``{"agent_name": ..., "agent_id": <str>, "configs": {key: value}}``
            where each value is the stored ``config_value``; keys are in
            ascending ``config_key`` order.

        Raises:
            AgentNotFoundError: If no agent named ``agent_name`` exists.
        """
        async with AsyncSessionLocal() as db:
            agent = await self._require_agent(db, agent_name)

            stmt = (
                select(AgentConfigModel)
                .where(AgentConfigModel.agent_id == agent.id)
                .order_by(AgentConfigModel.config_key)
            )
            result = await db.execute(stmt)
            configs = result.scalars().all()

            return {
                "agent_name": agent_name,
                "agent_id": str(agent.id),
                "configs": {cfg.config_key: cfg.config_value for cfg in configs},
            }

    async def set_config(
        self,
        agent_name: str,
        key: str,
        value: Any,
        updated_by: str | None = None,
    ) -> None:
        """Create or update a single configuration entry and commit it.

        Args:
            agent_name: Registry name of the agent.
            key: Configuration key.
            value: New value. A dict is stored as-is; anything else is stored
                as ``{"value": value}``.
            updated_by: Optional UUID string identifying who made the change.
                When empty/None, an existing row keeps its previous
                ``updated_by``.

        Raises:
            AgentNotFoundError: If no agent named ``agent_name`` exists.
            ValueError: If ``updated_by`` is non-empty but not a valid UUID.
            Exception: Any database error is logged, rolled back and re-raised.
        """
        async with AsyncSessionLocal() as db:
            try:
                agent = await self._require_agent(db, agent_name)
                updater_id = self._parse_user_id(updated_by)

                await self._upsert_config(db, agent.id, key, value, updater_id)

                await db.commit()
                logger.info("Config '%s' updated for agent '%s'", key, agent_name)
            except AgentNotFoundError:
                # Nothing was written yet; propagate without logging as a failure.
                raise
            except Exception:
                await db.rollback()
                logger.exception(
                    "Failed to set config '%s' for agent '%s'", key, agent_name
                )
                raise

    async def bulk_update_config(
        self,
        agent_name: str,
        configs: dict,
        updated_by: str | None = None,
    ) -> None:
        """Create or update several configuration entries in one transaction.

        Either every entry in ``configs`` is committed or, on error, none is.

        Args:
            agent_name: Registry name of the agent.
            configs: Mapping of configuration key to new value (values are
                wrapped the same way as in :meth:`set_config`).
            updated_by: Optional UUID string identifying who made the change.

        Raises:
            AgentNotFoundError: If no agent named ``agent_name`` exists.
            ValueError: If ``updated_by`` is non-empty but not a valid UUID.
            Exception: Any database error is logged, rolled back and re-raised.
        """
        async with AsyncSessionLocal() as db:
            try:
                agent = await self._require_agent(db, agent_name)
                # Parse once instead of once per key.
                updater_id = self._parse_user_id(updated_by)

                for key, value in configs.items():
                    await self._upsert_config(db, agent.id, key, value, updater_id)

                await db.commit()
                logger.info(
                    "Bulk config update (%d keys) for agent '%s'",
                    len(configs),
                    agent_name,
                )
            except AgentNotFoundError:
                raise
            except Exception:
                await db.rollback()
                logger.exception(
                    "Failed to bulk update config for agent '%s'", agent_name
                )
                raise

    async def get_config_history(
        self,
        agent_name: str,
        key: str | None = None,
    ) -> list[dict]:
        """Return configuration change history for an agent.

        Note: the current model keeps no historical versions, so this returns a
        snapshot of the current configuration rows. A ``config_versions`` table
        could be added in the future to store full history.

        Args:
            agent_name: Registry name of the agent.
            key: If given (non-empty), only the row with this key is returned.

        Returns:
            One dict per config row (id, agent_name, config_key, config_value,
            description, ISO ``updated_at``, ``updated_by``), ordered by
            ``config_key``.

        Raises:
            AgentNotFoundError: If no agent named ``agent_name`` exists.
        """
        async with AsyncSessionLocal() as db:
            agent = await self._require_agent(db, agent_name)

            stmt = select(AgentConfigModel).where(
                AgentConfigModel.agent_id == agent.id,
            )
            if key:
                stmt = stmt.where(AgentConfigModel.config_key == key)
            stmt = stmt.order_by(AgentConfigModel.config_key)

            result = await db.execute(stmt)
            configs = result.scalars().all()

            return [
                {
                    "id": str(cfg.id),
                    "agent_name": agent_name,
                    "config_key": cfg.config_key,
                    "config_value": cfg.config_value,
                    "description": cfg.description,
                    "updated_at": cfg.updated_at.isoformat() if cfg.updated_at else None,
                    "updated_by": str(cfg.updated_by) if cfg.updated_by else None,
                }
                for cfg in configs
            ]

    async def _get_agent_by_name(
        self, db: AsyncSession, agent_name: str
    ) -> AgentRegistryModel | None:
        """Look up an agent registry row by name; return None if absent."""
        stmt = select(AgentRegistryModel).where(
            AgentRegistryModel.name == agent_name
        )
        result = await db.execute(stmt)
        return result.scalar_one_or_none()

    async def _require_agent(
        self, db: AsyncSession, agent_name: str
    ) -> AgentRegistryModel:
        """Look up an agent by name, raising AgentNotFoundError if absent."""
        agent = await self._get_agent_by_name(db, agent_name)
        if agent is None:
            raise AgentNotFoundError(agent_name)
        return agent

    async def _upsert_config(
        self,
        db: AsyncSession,
        agent_id: Any,
        key: str,
        value: Any,
        updater_id: uuid.UUID | None,
    ) -> None:
        """Stage an insert or update of one config row in ``db``; caller commits."""
        stmt = select(AgentConfigModel).where(
            AgentConfigModel.agent_id == agent_id,
            AgentConfigModel.config_key == key,
        )
        result = await db.execute(stmt)
        existing = result.scalar_one_or_none()

        # Store values in a JSONB-compatible dict shape.
        config_value = self._wrap_value(value)

        if existing is None:
            db.add(
                AgentConfigModel(
                    agent_id=agent_id,
                    config_key=key,
                    config_value=config_value,
                    updated_by=updater_id,
                )
            )
            return

        existing.config_value = config_value
        # Only overwrite the recorded author when the caller identified one;
        # an anonymous update keeps the previous updated_by value.
        if updater_id is not None:
            existing.updated_by = updater_id

    @staticmethod
    def _parse_user_id(updated_by: str | None) -> uuid.UUID | None:
        """Convert an optional user-id string to a UUID; empty/None gives None.

        Raises:
            ValueError: If ``updated_by`` is non-empty but not a valid UUID.
        """
        return uuid.UUID(updated_by) if updated_by else None

    def _wrap_value(self, value: Any) -> dict:
        """Wrap any value into a JSONB-compatible dict.

        Dicts are returned unchanged; any other value becomes
        ``{"value": value}``.
        """
        if isinstance(value, dict):
            return value
        return {"value": value}
|