feat(agent): Wave 1 quick wins (G1/G2/G3/G8) + review fixes #4
|
|
@ -18,7 +18,7 @@ from agentkit.core.exceptions import LoopDetectedError, TaskCancelledError, Task
|
|||
from agentkit.core.protocol import CancellationToken
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.llm.protocol import LLMResponse
|
||||
from agentkit.tools.base import Tool
|
||||
from agentkit.tools.base import Tool, ToolValidationError
|
||||
from agentkit.telemetry.tracing import start_span, _OTEL_AVAILABLE
|
||||
from agentkit.telemetry.metrics import (
|
||||
agent_request_counter,
|
||||
|
|
@ -1910,6 +1910,15 @@ class ReActEngine:
|
|||
try:
|
||||
result = await tool.safe_execute(**clean_args)
|
||||
return result
|
||||
except ToolValidationError as e:
|
||||
# 保留类型化错误码,不被通用 except 平坦化为字符串
|
||||
error_msg = f"Tool '{tool_name}' schema validation failed: {e}"
|
||||
logger.warning(error_msg)
|
||||
return {
|
||||
"error": str(e),
|
||||
"error_code": e.error_code,
|
||||
"details": e.details,
|
||||
}
|
||||
except Exception as e:
|
||||
error_msg = f"Tool '{tool_name}' execution failed: {e}"
|
||||
logger.warning(error_msg)
|
||||
|
|
|
|||
|
|
@ -4,10 +4,32 @@ import time
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
import jsonschema
|
||||
|
||||
from agentkit.telemetry.tracing import start_span
|
||||
from agentkit.telemetry.metrics import tool_duration_histogram
|
||||
|
||||
|
||||
class ToolValidationError(Exception):
|
||||
"""工具参数 schema 校验失败。
|
||||
|
||||
error_code:
|
||||
- "tool_call_invalid" 类型不匹配(jsonschema.ValidationError.path 末段为字段名)
|
||||
- "schema_mismatch" 必填缺失 / 结构性错误(默认兜底)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
error_code: str = "schema_mismatch",
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.error_code = error_code
|
||||
self.details = details or {}
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
"""工具抽象基类
|
||||
|
||||
|
|
@ -57,6 +79,7 @@ class Tool(ABC):
|
|||
_start = time.monotonic()
|
||||
try:
|
||||
await self.before_execute(**kwargs)
|
||||
self._validate_input(kwargs)
|
||||
result = await self.execute(**kwargs)
|
||||
await self.after_execute(result, **kwargs)
|
||||
_duration_ms = int((time.monotonic() - _start) * 1000)
|
||||
|
|
@ -76,6 +99,34 @@ class Tool(ABC):
|
|||
finally:
|
||||
_span_cm.__exit__(None, None, None)
|
||||
|
||||
def _validate_input(self, kwargs: dict[str, Any]) -> None:
    """Check ``kwargs`` against ``self.input_schema``.

    Nothing is validated when ``self.input_schema`` is None, which keeps
    tools without a schema working. Otherwise ``kwargs`` is handed to
    ``jsonschema.validate`` and a ``jsonschema.ValidationError`` is
    re-raised as ``ToolValidationError`` carrying:

    - error_code "schema_mismatch" when the failing validator is
      ``required``
    - error_code "tool_call_invalid" for every other failing validator
    - details with the dotted field path ("<root>" when the path is
      empty), the failing validator keyword and the absolute schema path

    Raises:
        ToolValidationError: ``kwargs`` does not satisfy the schema.
    """
    if self.input_schema is None:
        return

    try:
        jsonschema.validate(instance=kwargs, schema=self.input_schema)
    except jsonschema.ValidationError as exc:
        dotted_path = ".".join(str(part) for part in exc.absolute_path)
        failure_details = {
            "field": dotted_path or "<root>",
            "validator": exc.validator,
            "schema_path": list(exc.absolute_schema_path),
        }
        # Only a missing required field counts as schema_mismatch; any other
        # validator failure is reported as tool_call_invalid.
        error_code = (
            "schema_mismatch" if exc.validator == "required" else "tool_call_invalid"
        )
        raise ToolValidationError(
            f"Tool '{self.name}' argument validation failed: {exc.message}",
            error_code=error_code,
            details=failure_details,
        ) from exc
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"name": self.name,
|
||||
|
|
|
|||
|
|
@ -50,6 +50,13 @@ class FunctionTool(Tool):
|
|||
for param_name, param in sig.parameters.items():
|
||||
if param_name in ("self", "cls"):
|
||||
continue
|
||||
# ponytail: VAR_KEYWORD(**kwargs)/VAR_POSITIONAL(*args) 是 catch-all,
|
||||
# 不是具体参数,不进 schema。否则 schema 校验会要求 kwargs 字段必填。
|
||||
if param.kind in (
|
||||
inspect.Parameter.VAR_KEYWORD,
|
||||
inspect.Parameter.VAR_POSITIONAL,
|
||||
):
|
||||
continue
|
||||
|
||||
param_type = "string"
|
||||
if param.annotation != inspect.Parameter.empty:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,124 @@
|
|||
"""U1 / G3 工具调用 schema 校验测试。
|
||||
|
||||
覆盖 R8-R10:
|
||||
- R8 schema 校验在 execute 前运行
|
||||
- R9 类型化错误码 (tool_call_invalid / schema_mismatch)
|
||||
- R10 _execute_tool 捕获后回灌 conversation(结构化 dict 保留 error_code)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.tools.base import Tool, ToolValidationError
|
||||
|
||||
|
||||
class _StubTool(Tool):
    """Test stub that records each ``execute`` call and its arguments."""

    def __init__(self, *, schema=None, payload=None):
        super().__init__(name="stub", description="stub", input_schema=schema)
        # A falsy payload (None, empty dict, ...) falls back to the default.
        self._payload = payload if payload else {"ok": True}
        self.calls: list[dict] = []

    async def execute(self, **kwargs) -> dict:
        """Record the received keyword arguments and return the payload."""
        self.calls.append(kwargs)
        return self._payload
|
||||
|
||||
|
||||
# Schema shared by the tests below: an object whose only allowed property is
# the required integer ``count``.
_SCHEMA = {
    "type": "object",
    "properties": {
        "count": {"type": "integer"},
    },
    "required": ["count"],
    "additionalProperties": False,
}
|
||||
|
||||
|
||||
# ---- R8 happy path ----
|
||||
|
||||
|
||||
async def test_schema_valid_passes_through_to_execute():
    """Schema-conforming arguments reach ``execute`` and its result is returned."""
    stub = _StubTool(schema=_SCHEMA)
    outcome = await stub.safe_execute(count=5)
    assert outcome == {"ok": True}
    assert stub.calls == [{"count": 5}]
|
||||
|
||||
|
||||
# ---- R9 backward compat: input_schema=None skips ----
|
||||
|
||||
|
||||
async def test_none_schema_skips_validation():
    """Without a schema, arbitrary keyword arguments are passed to ``execute``."""
    stub = _StubTool(schema=None)
    arbitrary_args = {"anything": "ok", "even_invalid_types": 123}
    await stub.safe_execute(**arbitrary_args)
    assert stub.calls == [{"anything": "ok", "even_invalid_types": 123}]
|
||||
|
||||
|
||||
# ---- R9 type mismatch → tool_call_invalid ----
|
||||
|
||||
|
||||
async def test_type_mismatch_raises_tool_call_invalid():
    """A wrongly typed argument raises ``tool_call_invalid`` and skips ``execute``."""
    stub = _StubTool(schema=_SCHEMA)
    with pytest.raises(ToolValidationError) as caught:
        await stub.safe_execute(count="abc")
    raised = caught.value
    assert raised.error_code == "tool_call_invalid"
    assert raised.details["field"] == "count"
    assert stub.calls == []  # execute was never called
|
||||
|
||||
|
||||
# ---- R9 missing required → schema_mismatch ----
|
||||
|
||||
|
||||
async def test_missing_required_raises_schema_mismatch():
    """Omitting the required argument raises ``schema_mismatch`` and skips ``execute``."""
    stub = _StubTool(schema=_SCHEMA)
    with pytest.raises(ToolValidationError) as caught:
        await stub.safe_execute()
    raised = caught.value
    assert raised.error_code == "schema_mismatch"
    assert raised.details["validator"] == "required"
    assert stub.calls == []
|
||||
|
||||
|
||||
# ---- R10 integration: _execute_tool catches and returns structured dict ----
|
||||
|
||||
|
||||
async def test_execute_tool_catches_validation_error_returns_structured_dict():
    """ReActEngine._execute_tool should catch ToolValidationError first and keep error_code."""
    from agentkit.core.react import ReActEngine

    stub = _StubTool(schema=_SCHEMA)
    # Allocate the engine without running __init__.
    engine = ReActEngine.__new__(ReActEngine)
    bad_args = {"count": "not-int"}
    outcome = await engine._execute_tool("stub", bad_args, [stub])

    assert "error" in outcome
    assert outcome["error_code"] == "tool_call_invalid"
    assert "details" in outcome
    assert stub.calls == []  # execute did not run
|
||||
|
||||
|
||||
async def test_execute_tool_catches_missing_required_returns_schema_mismatch():
    """With the required argument missing, ``_execute_tool`` returns a ``schema_mismatch`` dict."""
    from agentkit.core.react import ReActEngine

    stub = _StubTool(schema=_SCHEMA)
    # Allocate the engine without running __init__.
    engine = ReActEngine.__new__(ReActEngine)
    no_args = {}
    outcome = await engine._execute_tool("stub", no_args, [stub])

    assert outcome["error_code"] == "schema_mismatch"
    assert outcome["details"]["validator"] == "required"
|
||||
|
||||
|
||||
# ---- R10 self-check: structured error string survives into tool message ----
|
||||
|
||||
|
||||
async def test_structured_error_dict_str_includes_error_code_for_llm_self_correction():
    """_build_tool_result_message turns the dict into str(content); the text the LLM sees contains error_code."""
    from agentkit.core.react import ReActEngine

    stub = _StubTool(schema=_SCHEMA)
    # Allocate the engine without running __init__.
    engine = ReActEngine.__new__(ReActEngine)
    tool_result = await engine._execute_tool("stub", {"count": "bad"}, [stub])
    message = await engine._build_tool_result_message(
        tool_call_id="t1",
        result=tool_result,
    )
    assert message["role"] == "tool"
    assert message["tool_call_id"] == "t1"
    assert "tool_call_invalid" in message["content"]
|
||||
Loading…
Reference in New Issue