72 lines
2.2 KiB
Python
72 lines
2.2 KiB
Python
"""ToolRegistry - 工具注册中心"""
|
|
|
|
import logging
|
|
|
|
from agentkit.core.exceptions import ToolNotFoundError
|
|
from agentkit.tools.base import Tool
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ToolRegistry:
    """Tool registry: manages registration, lookup and versioning of tools.

    Tools are stored per name, and per version within a name. When no explicit
    version is requested, the "latest" version is chosen by a numeric-aware
    comparison of the version strings (so ``"1.10.0"`` is newer than
    ``"1.9.0"``), not by plain string ordering.
    """

    def __init__(self):
        # name -> {version -> tool}; inner dicts keep insertion order and are
        # never left empty (unregister drops a name once its last version goes).
        self._tools: dict[str, dict[str, Tool]] = {}

    @staticmethod
    def _version_key(version: str) -> tuple:
        """Return a sort key that orders version strings numerically.

        The string is split into runs of digits and runs of letters; any other
        character (``.``, ``-``, ``_``, ``+``, ...) only separates runs. Digit
        runs compare as integers and sort before letter runs at the same
        position, so ints are never compared with strs. Pre-release suffixes
        get no special treatment: ``"1.0.0rc1"`` sorts after ``"1.0.0"``
        because it has more components.
        """
        import re  # local import: only this helper needs it

        return tuple(
            # isdecimal() is exactly the set of characters \d matched, all of
            # which int() accepts.
            (0, int(token)) if token.isdecimal() else (1, token)
            for token in re.findall(r"\d+|[^\W\d_]+", str(version))
        )

    def _latest(self, versions: dict[str, Tool]) -> Tool:
        """Return the tool with the highest version from a non-empty mapping."""
        # The raw version string breaks ties between versions whose keys are
        # equal (e.g. "1.0" vs "1-0"), keeping the choice deterministic.
        newest = max(versions, key=lambda v: (self._version_key(v), v))
        return versions[newest]

    def register(self, tool: Tool) -> "ToolRegistry":
        """Register ``tool`` under its ``name`` and ``version``.

        Registering the same name/version pair again replaces the earlier tool
        (a warning is logged). Returns ``self`` so calls can be chained.
        """
        versions = self._tools.setdefault(tool.name, {})
        if tool.version in versions:
            logger.warning(
                "Tool '%s' v%s already registered; replacing it",
                tool.name,
                tool.version,
            )
        versions[tool.version] = tool
        logger.info("Tool '%s' v%s registered", tool.name, tool.version)
        return self

    def unregister(self, name: str, version: str | None = None) -> None:
        """Remove a tool, or a single version of it.

        Args:
            name: tool name.
            version: version to remove. ``None`` (the default) removes every
                version of ``name``. Otherwise only that version is removed,
                and the name itself is dropped once no versions remain.

        Unknown names or versions are ignored; no error is raised.
        """
        versions = self._tools.get(name)
        if versions is None:
            return
        # Explicit None check: an empty-string version must not be mistaken
        # for "remove everything".
        if version is None:
            del self._tools[name]
            return
        versions.pop(version, None)
        if not versions:
            del self._tools[name]

    def get(self, name: str, version: str | None = None) -> Tool:
        """Return a registered tool.

        Args:
            name: tool name.
            version: exact version to fetch. ``None`` (the default) selects the
                latest version.

        Raises:
            ToolNotFoundError: if ``name`` is unknown, or ``version`` is given
                but not registered for it.
        """
        versions = self._tools.get(name)
        if versions is None:
            raise ToolNotFoundError(name)
        if version is None:
            return self._latest(versions)
        if version not in versions:
            raise ToolNotFoundError(f"{name}@{version}")
        return versions[version]

    def list_tools(self, tag: str | None = None) -> list[Tool]:
        """Return the latest version of every registered tool.

        If ``tag`` is given, only tools whose ``tags`` contain it are returned.
        Results follow the registry's insertion order of tool names.
        """
        latest_tools = (self._latest(versions) for versions in self._tools.values())
        return [tool for tool in latest_tools if tag is None or tag in tool.tags]

    def list_all_versions(self, name: str) -> list[Tool]:
        """Return every registered version of ``name`` in insertion order.

        Returns an empty list when ``name`` is unknown.
        """
        return list(self._tools.get(name, {}).values())

    def has_tool(self, name: str) -> bool:
        """Return True if at least one version of ``name`` is registered."""
        return name in self._tools

    def clear(self) -> None:
        """Remove every registered tool."""
        self._tools.clear()
|