# geo/backend/app/agent_framework/config_manager.py
#
# 191 lines
# 6.7 KiB
# Python

"""Agent 配置管理 - 支持热更新"""
import logging
import uuid
from datetime import datetime, timezone
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.agent_framework.exceptions import AgentNotFoundError, ConfigValidationError
from app.database import AsyncSessionLocal
from app.models.agent import AgentConfig as AgentConfigModel
from app.models.agent import AgentRegistry as AgentRegistryModel
logger = logging.getLogger(__name__)
class AgentConfigManager:
"""Agent 配置管理,支持热更新"""
async def get_config(self, agent_name: str) -> dict:
"""获取 Agent 的完整配置"""
async with AsyncSessionLocal() as db:
agent = await self._get_agent_by_name(db, agent_name)
if not agent:
raise AgentNotFoundError(agent_name)
stmt = select(AgentConfigModel).where(
AgentConfigModel.agent_id == agent.id
)
result = await db.execute(stmt)
configs = result.scalars().all()
return {
"agent_name": agent_name,
"agent_id": str(agent.id),
"configs": {
cfg.config_key: cfg.config_value for cfg in configs
},
}
async def set_config(
self,
agent_name: str,
key: str,
value: Any,
updated_by: str | None = None,
):
"""设置单个配置项"""
async with AsyncSessionLocal() as db:
try:
agent = await self._get_agent_by_name(db, agent_name)
if not agent:
raise AgentNotFoundError(agent_name)
# 查找已有配置
stmt = select(AgentConfigModel).where(
AgentConfigModel.agent_id == agent.id,
AgentConfigModel.config_key == key,
)
result = await db.execute(stmt)
existing = result.scalar_one_or_none()
# 将 value 包装为 JSONB 兼容的 dict
config_value = self._wrap_value(value)
if existing:
existing.config_value = config_value
if updated_by:
existing.updated_by = uuid.UUID(updated_by)
else:
new_config = AgentConfigModel(
agent_id=agent.id,
config_key=key,
config_value=config_value,
updated_by=uuid.UUID(updated_by) if updated_by else None,
)
db.add(new_config)
await db.commit()
logger.info(f"Config '{key}' updated for agent '{agent_name}'")
except AgentNotFoundError:
raise
except Exception as e:
await db.rollback()
logger.error(f"Failed to set config '{key}' for agent '{agent_name}': {e}")
raise
async def bulk_update_config(
self,
agent_name: str,
configs: dict,
updated_by: str | None = None,
):
"""批量更新配置"""
async with AsyncSessionLocal() as db:
try:
agent = await self._get_agent_by_name(db, agent_name)
if not agent:
raise AgentNotFoundError(agent_name)
for key, value in configs.items():
config_value = self._wrap_value(value)
stmt = select(AgentConfigModel).where(
AgentConfigModel.agent_id == agent.id,
AgentConfigModel.config_key == key,
)
result = await db.execute(stmt)
existing = result.scalar_one_or_none()
if existing:
existing.config_value = config_value
if updated_by:
existing.updated_by = uuid.UUID(updated_by)
else:
new_config = AgentConfigModel(
agent_id=agent.id,
config_key=key,
config_value=config_value,
updated_by=uuid.UUID(updated_by) if updated_by else None,
)
db.add(new_config)
await db.commit()
logger.info(
f"Bulk config update ({len(configs)} keys) for agent '{agent_name}'"
)
except AgentNotFoundError:
raise
except Exception as e:
await db.rollback()
logger.error(f"Failed to bulk update config for agent '{agent_name}': {e}")
raise
async def get_config_history(
self,
agent_name: str,
key: str | None = None,
) -> list[dict]:
"""
获取配置变更历史。
注意:当前模型不保留历史版本,此方法返回当前配置快照。
未来可扩展 config_versions 表来存储完整历史。
"""
async with AsyncSessionLocal() as db:
agent = await self._get_agent_by_name(db, agent_name)
if not agent:
raise AgentNotFoundError(agent_name)
stmt = select(AgentConfigModel).where(
AgentConfigModel.agent_id == agent.id,
)
if key:
stmt = stmt.where(AgentConfigModel.config_key == key)
result = await db.execute(stmt)
configs = result.scalars().all()
return [
{
"id": str(cfg.id),
"agent_name": agent_name,
"config_key": cfg.config_key,
"config_value": cfg.config_value,
"description": cfg.description,
"updated_at": cfg.updated_at.isoformat() if cfg.updated_at else None,
"updated_by": str(cfg.updated_by) if cfg.updated_by else None,
}
for cfg in configs
]
async def _get_agent_by_name(
self, db: AsyncSession, agent_name: str
) -> AgentRegistryModel | None:
"""根据名称查找 Agent"""
stmt = select(AgentRegistryModel).where(
AgentRegistryModel.name == agent_name
)
result = await db.execute(stmt)
return result.scalar_one_or_none()
def _wrap_value(self, value: Any) -> dict:
"""将任意值包装为 JSONB 兼容的 dict"""
if isinstance(value, dict):
return value
return {"value": value}