From c66a7773b546096b46d7e80957c3d6b66491c69e Mon Sep 17 00:00:00 2001 From: chiguyong Date: Mon, 29 Jun 2026 20:34:14 +0800 Subject: [PATCH] =?UTF-8?q?feat(U1):=20G3=20=E5=B7=A5=E5=85=B7=E8=B0=83?= =?UTF-8?q?=E7=94=A8=20schema=20=E6=A0=A1=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - base.py 新增 ToolValidationError(error_code/details)与 _validate_input - safe_execute 在 execute 前用 jsonschema.validate 校验 kwargs - input_schema=None 跳过校验保持向后兼容 - _execute_tool 优先捕获 ToolValidationError 保留 error_code - function_tool._infer_schema 修复 VAR_KEYWORD/VAR_POSITIONAL 误入 schema - test_tool_schema_validation.py 覆盖 R8-R10 --- src/agentkit/core/react.py | 11 +- src/agentkit/tools/base.py | 51 +++++++++ src/agentkit/tools/function_tool.py | 7 ++ tests/unit/test_tool_schema_validation.py | 124 ++++++++++++++++++++++ 4 files changed, 192 insertions(+), 1 deletion(-) create mode 100644 tests/unit/test_tool_schema_validation.py diff --git a/src/agentkit/core/react.py b/src/agentkit/core/react.py index b0bb29b..5e1c50c 100644 --- a/src/agentkit/core/react.py +++ b/src/agentkit/core/react.py @@ -18,7 +18,7 @@ from agentkit.core.exceptions import LoopDetectedError, TaskCancelledError, Task from agentkit.core.protocol import CancellationToken from agentkit.llm.gateway import LLMGateway from agentkit.llm.protocol import LLMResponse -from agentkit.tools.base import Tool +from agentkit.tools.base import Tool, ToolValidationError from agentkit.telemetry.tracing import start_span, _OTEL_AVAILABLE from agentkit.telemetry.metrics import ( agent_request_counter, @@ -1910,6 +1910,15 @@ class ReActEngine: try: result = await tool.safe_execute(**clean_args) return result + except ToolValidationError as e: + # 保留类型化错误码,不被通用 except 平坦化为字符串 + error_msg = f"Tool '{tool_name}' schema validation failed: {e}" + logger.warning(error_msg) + return { + "error": str(e), + "error_code": e.error_code, + "details": e.details, + } except Exception as e: error_msg = f"Tool '{tool_name}' execution failed: {e}" logger.warning(error_msg) diff --git a/src/agentkit/tools/base.py b/src/agentkit/tools/base.py index 79a1706..6829a2b 100644 --- a/src/agentkit/tools/base.py +++ b/src/agentkit/tools/base.py @@ -4,10 +4,32 @@ import time from abc import ABC, abstractmethod from typing import Any +import jsonschema + from agentkit.telemetry.tracing import start_span from agentkit.telemetry.metrics import tool_duration_histogram +class ToolValidationError(Exception): + """工具参数 schema 校验失败。 + + error_code: + - "tool_call_invalid" 类型不匹配(jsonschema.ValidationError.path 末段为字段名) + - "schema_mismatch" 必填缺失 / 结构性错误(默认兜底) + """ + + def __init__( + self, + message: str, + *, + error_code: str = "schema_mismatch", + details: dict[str, Any] | None = None, + ) -> None: + super().__init__(message) + self.error_code = error_code + self.details = details or {} + + class Tool(ABC): """工具抽象基类 @@ -57,6 +79,7 @@ class Tool(ABC): _start = time.monotonic() try: await self.before_execute(**kwargs) + self._validate_input(kwargs) result = await self.execute(**kwargs) await self.after_execute(result, **kwargs) _duration_ms = int((time.monotonic() - _start) * 1000) @@ -76,6 +99,34 @@ class Tool(ABC): finally: _span_cm.__exit__(None, None, None) + def _validate_input(self, kwargs: dict[str, Any]) -> None: + """校验 kwargs 是否符合 self.input_schema。 + + - input_schema=None → 跳过(向后兼容,旧工具无 schema) + - 类型不匹配 → error_code="tool_call_invalid" + - 必填缺失 / 结构性错误 → error_code="schema_mismatch" + """ + if self.input_schema is None: + return + try: + jsonschema.validate(instance=kwargs, schema=self.input_schema) + except jsonschema.ValidationError as e: + field_path = ".".join(str(p) for p in e.absolute_path) or "" + # required 缺失走 schema_mismatch;类型不符走 tool_call_invalid + if e.validator == "required": + code = "schema_mismatch" + else: + code = "tool_call_invalid" + raise ToolValidationError( + f"Tool '{self.name}' argument validation failed: {e.message}", + error_code=code, + details={ + "field": field_path, + "validator": e.validator, + "schema_path": list(e.absolute_schema_path), + }, + ) from e + def to_dict(self) -> dict: return { "name": self.name, diff --git a/src/agentkit/tools/function_tool.py b/src/agentkit/tools/function_tool.py index 92570a9..c1086d9 100644 --- a/src/agentkit/tools/function_tool.py +++ b/src/agentkit/tools/function_tool.py @@ -50,6 +50,13 @@ class FunctionTool(Tool): for param_name, param in sig.parameters.items(): if param_name in ("self", "cls"): continue + # ponytail: VAR_KEYWORD(**kwargs)/VAR_POSITIONAL(*args) 是 catch-all, + # 不是具体参数,不进 schema。否则 schema 校验会要求 kwargs 字段必填。 + if param.kind in ( + inspect.Parameter.VAR_KEYWORD, + inspect.Parameter.VAR_POSITIONAL, + ): + continue param_type = "string" if param.annotation != inspect.Parameter.empty: diff --git a/tests/unit/test_tool_schema_validation.py b/tests/unit/test_tool_schema_validation.py new file mode 100644 index 0000000..f7b9e89 --- /dev/null +++ b/tests/unit/test_tool_schema_validation.py @@ -0,0 +1,124 @@ +"""U1 / G3 工具调用 schema 校验测试。 + +覆盖 R8-R10: +- R8 schema 校验在 execute 前运行 +- R9 类型化错误码 (tool_call_invalid / schema_mismatch) +- R10 _execute_tool 捕获后回灌 conversation(结构化 dict 保留 error_code) +""" + +from __future__ import annotations + +import pytest + +from agentkit.tools.base import Tool, ToolValidationError + + +class _StubTool(Tool): + """测试用 stub,记录调用与参数。""" + + def __init__(self, *, schema=None, payload=None): + super().__init__(name="stub", description="stub", input_schema=schema) + self.calls: list[dict] = [] + self._payload = payload or {"ok": True} + + async def execute(self, **kwargs) -> dict: + self.calls.append(kwargs) + return self._payload + + +_SCHEMA = { + "type": "object", + "properties": {"count": {"type": "integer"}}, + "required": ["count"], + "additionalProperties": False, +} + + +# ---- R8 happy path ---- + + +async def test_schema_valid_passes_through_to_execute(): + tool = _StubTool(schema=_SCHEMA) + result = await tool.safe_execute(count=5) + assert result == {"ok": True} + assert tool.calls == [{"count": 5}] + + +# ---- R9 backward compat: input_schema=None skips ---- + + +async def test_none_schema_skips_validation(): + tool = _StubTool(schema=None) + await tool.safe_execute(anything="ok", even_invalid_types=123) + assert tool.calls == [{"anything": "ok", "even_invalid_types": 123}] + + +# ---- R9 type mismatch → tool_call_invalid ---- + + +async def test_type_mismatch_raises_tool_call_invalid(): + tool = _StubTool(schema=_SCHEMA) + with pytest.raises(ToolValidationError) as ei: + await tool.safe_execute(count="abc") + assert ei.value.error_code == "tool_call_invalid" + assert ei.value.details["field"] == "count" + assert tool.calls == [] # execute not called + + +# ---- R9 missing required → schema_mismatch ---- + + +async def test_missing_required_raises_schema_mismatch(): + tool = _StubTool(schema=_SCHEMA) + with pytest.raises(ToolValidationError) as ei: + await tool.safe_execute() + assert ei.value.error_code == "schema_mismatch" + assert ei.value.details["validator"] == "required" + assert tool.calls == [] + + +# ---- R10 integration: _execute_tool catches and returns structured dict ---- + + +async def test_execute_tool_catches_validation_error_returns_structured_dict(): + """ReActEngine._execute_tool 应优先捕获 ToolValidationError 并保留 error_code。""" + from agentkit.core.react import ReActEngine + + tool = _StubTool(schema=_SCHEMA) + engine = ReActEngine.__new__(ReActEngine) # 绕过 __init__ + result = await engine._execute_tool("stub", {"count": "not-int"}, [tool]) + + assert "error" in result + assert result["error_code"] == "tool_call_invalid" + assert "details" in result + assert tool.calls == [] # execute 未执行 + + +async def test_execute_tool_catches_missing_required_returns_schema_mismatch(): + from agentkit.core.react import ReActEngine + + tool = _StubTool(schema=_SCHEMA) + engine = ReActEngine.__new__(ReActEngine) + result = await engine._execute_tool("stub", {}, [tool]) + + assert result["error_code"] == "schema_mismatch" + assert result["details"]["validator"] == "required" + + +# ---- R10 self-check: structured error string survives into tool message ---- + + +async def test_structured_error_dict_str_includes_error_code_for_llm_self_correction(): + """_build_tool_result_message 把 dict 转 str(content),LLM 看到文本含 error_code。""" + from agentkit.core.react import ReActEngine + + tool = _StubTool(schema=_SCHEMA) + engine = ReActEngine.__new__(ReActEngine) + result = await engine._execute_tool("stub", {"count": "bad"}, [tool]) + msg = await engine._build_tool_result_message( + tool_call_id="t1", + result=result, + ) + assert msg["role"] == "tool" + assert msg["tool_call_id"] == "t1" + assert "tool_call_invalid" in msg["content"]