142 lines
4.7 KiB
Python
142 lines
4.7 KiB
Python
"""Tool 抽象基类 - 统一工具接口"""
|
||
|
||
import time
|
||
from abc import ABC, abstractmethod
|
||
from typing import Any
|
||
|
||
import jsonschema
|
||
|
||
from agentkit.telemetry.tracing import start_span
|
||
from agentkit.telemetry.metrics import tool_duration_histogram
|
||
|
||
|
||
class ToolValidationError(Exception):
|
||
"""工具参数 schema 校验失败。
|
||
|
||
error_code:
|
||
- "tool_call_invalid" 类型不匹配(jsonschema.ValidationError.path 末段为字段名)
|
||
- "schema_mismatch" 必填缺失 / 结构性错误(默认兜底)
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
message: str,
|
||
*,
|
||
error_code: str = "schema_mismatch",
|
||
details: dict[str, Any] | None = None,
|
||
) -> None:
|
||
super().__init__(message)
|
||
self.error_code = error_code
|
||
self.details = details or {}
|
||
|
||
|
||
class Tool(ABC):
|
||
"""工具抽象基类
|
||
|
||
所有工具(FunctionTool, AgentTool, MCPTool)的统一接口。
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
name: str,
|
||
description: str,
|
||
input_schema: dict[str, Any] | None = None,
|
||
output_schema: dict[str, Any] | None = None,
|
||
version: str = "1.0.0",
|
||
tags: list[str] | None = None,
|
||
):
|
||
self.name = name
|
||
self.description = description
|
||
self.input_schema = input_schema
|
||
self.output_schema = output_schema
|
||
self.version = version
|
||
self.tags = tags or []
|
||
|
||
@abstractmethod
|
||
async def execute(self, **kwargs) -> dict:
|
||
"""执行工具,返回结果 dict"""
|
||
...
|
||
|
||
async def before_execute(self, **kwargs) -> None:
|
||
"""执行前钩子"""
|
||
pass
|
||
|
||
async def after_execute(self, result: dict, **kwargs) -> None:
|
||
"""执行后钩子"""
|
||
pass
|
||
|
||
async def on_error(self, error: Exception, **kwargs) -> None:
|
||
"""错误钩子"""
|
||
pass
|
||
|
||
async def safe_execute(self, **kwargs) -> dict:
|
||
"""带钩子的安全执行"""
|
||
_span_cm = start_span(
|
||
"tool.execute",
|
||
attributes={"tool.name": self.name},
|
||
)
|
||
_span = _span_cm.__enter__()
|
||
_start = time.monotonic()
|
||
try:
|
||
await self.before_execute(**kwargs)
|
||
self._validate_input(kwargs)
|
||
result = await self.execute(**kwargs)
|
||
await self.after_execute(result, **kwargs)
|
||
_duration_ms = int((time.monotonic() - _start) * 1000)
|
||
if _span is not None:
|
||
_span.set_attribute("tool.duration_ms", _duration_ms)
|
||
_span.set_attribute("tool.result.success", True)
|
||
tool_duration_histogram().record(_duration_ms, {"tool.name": self.name})
|
||
return result
|
||
except Exception as e:
|
||
_duration_ms = int((time.monotonic() - _start) * 1000)
|
||
if _span is not None:
|
||
_span.set_attribute("tool.duration_ms", _duration_ms)
|
||
_span.set_attribute("tool.result.success", False)
|
||
tool_duration_histogram().record(_duration_ms, {"tool.name": self.name})
|
||
await self.on_error(e, **kwargs)
|
||
raise
|
||
finally:
|
||
_span_cm.__exit__(None, None, None)
|
||
|
||
def _validate_input(self, kwargs: dict[str, Any]) -> None:
|
||
"""校验 kwargs 是否符合 self.input_schema。
|
||
|
||
- input_schema=None → 跳过(向后兼容,旧工具无 schema)
|
||
- 类型不匹配 → error_code="tool_call_invalid"
|
||
- 必填缺失 / 结构性错误 → error_code="schema_mismatch"
|
||
"""
|
||
if self.input_schema is None:
|
||
return
|
||
try:
|
||
jsonschema.validate(instance=kwargs, schema=self.input_schema)
|
||
except jsonschema.ValidationError as e:
|
||
field_path = ".".join(str(p) for p in e.absolute_path) or "<root>"
|
||
# required 缺失走 schema_mismatch;类型不符走 tool_call_invalid
|
||
if e.validator == "required":
|
||
code = "schema_mismatch"
|
||
else:
|
||
code = "tool_call_invalid"
|
||
raise ToolValidationError(
|
||
f"Tool '{self.name}' argument validation failed: {e.message}",
|
||
error_code=code,
|
||
details={
|
||
"field": field_path,
|
||
"validator": e.validator,
|
||
"schema_path": list(e.absolute_schema_path),
|
||
},
|
||
) from e
|
||
|
||
def to_dict(self) -> dict:
|
||
return {
|
||
"name": self.name,
|
||
"description": self.description,
|
||
"input_schema": self.input_schema,
|
||
"output_schema": self.output_schema,
|
||
"version": self.version,
|
||
"tags": self.tags,
|
||
}
|
||
|
||
def __repr__(self) -> str:
|
||
return f"<{type(self).__name__} name={self.name!r} version={self.version}>"
|