fischer-agentkit/src/agentkit/tools/base.py

142 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Tool 抽象基类 - 统一工具接口"""
import time
from abc import ABC, abstractmethod
from typing import Any
import jsonschema
from agentkit.telemetry.tracing import start_span
from agentkit.telemetry.metrics import tool_duration_histogram
class ToolValidationError(Exception):
"""工具参数 schema 校验失败。
error_code:
- "tool_call_invalid" 类型不匹配(jsonschema.ValidationError.path 末段为字段名)
- "schema_mismatch" 必填缺失 / 结构性错误(默认兜底)
"""
def __init__(
self,
message: str,
*,
error_code: str = "schema_mismatch",
details: dict[str, Any] | None = None,
) -> None:
super().__init__(message)
self.error_code = error_code
self.details = details or {}
class Tool(ABC):
"""工具抽象基类
所有工具FunctionTool, AgentTool, MCPTool的统一接口。
"""
def __init__(
self,
name: str,
description: str,
input_schema: dict[str, Any] | None = None,
output_schema: dict[str, Any] | None = None,
version: str = "1.0.0",
tags: list[str] | None = None,
):
self.name = name
self.description = description
self.input_schema = input_schema
self.output_schema = output_schema
self.version = version
self.tags = tags or []
@abstractmethod
async def execute(self, **kwargs) -> dict:
"""执行工具,返回结果 dict"""
...
async def before_execute(self, **kwargs) -> None:
"""执行前钩子"""
pass
async def after_execute(self, result: dict, **kwargs) -> None:
"""执行后钩子"""
pass
async def on_error(self, error: Exception, **kwargs) -> None:
"""错误钩子"""
pass
async def safe_execute(self, **kwargs) -> dict:
"""带钩子的安全执行"""
_span_cm = start_span(
"tool.execute",
attributes={"tool.name": self.name},
)
_span = _span_cm.__enter__()
_start = time.monotonic()
try:
await self.before_execute(**kwargs)
self._validate_input(kwargs)
result = await self.execute(**kwargs)
await self.after_execute(result, **kwargs)
_duration_ms = int((time.monotonic() - _start) * 1000)
if _span is not None:
_span.set_attribute("tool.duration_ms", _duration_ms)
_span.set_attribute("tool.result.success", True)
tool_duration_histogram().record(_duration_ms, {"tool.name": self.name})
return result
except Exception as e:
_duration_ms = int((time.monotonic() - _start) * 1000)
if _span is not None:
_span.set_attribute("tool.duration_ms", _duration_ms)
_span.set_attribute("tool.result.success", False)
tool_duration_histogram().record(_duration_ms, {"tool.name": self.name})
await self.on_error(e, **kwargs)
raise
finally:
_span_cm.__exit__(None, None, None)
def _validate_input(self, kwargs: dict[str, Any]) -> None:
"""校验 kwargs 是否符合 self.input_schema。
- input_schema=None → 跳过(向后兼容,旧工具无 schema)
- 类型不匹配 → error_code="tool_call_invalid"
- 必填缺失 / 结构性错误 → error_code="schema_mismatch"
"""
if self.input_schema is None:
return
try:
jsonschema.validate(instance=kwargs, schema=self.input_schema)
except jsonschema.ValidationError as e:
field_path = ".".join(str(p) for p in e.absolute_path) or "<root>"
# required 缺失走 schema_mismatch;类型不符走 tool_call_invalid
if e.validator == "required":
code = "schema_mismatch"
else:
code = "tool_call_invalid"
raise ToolValidationError(
f"Tool '{self.name}' argument validation failed: {e.message}",
error_code=code,
details={
"field": field_path,
"validator": e.validator,
"schema_path": list(e.absolute_schema_path),
},
) from e
def to_dict(self) -> dict:
return {
"name": self.name,
"description": self.description,
"input_schema": self.input_schema,
"output_schema": self.output_schema,
"version": self.version,
"tags": self.tags,
}
def __repr__(self) -> str:
return f"<{type(self).__name__} name={self.name!r} version={self.version}>"