# fischer-agentkit/tests/unit/test_tool_schema_validation.py

"""U1 / G3 工具调用 schema 校验测试。
覆盖 R8-R10:
- R8 schema 校验在 execute 前运行
- R9 类型化错误码 (tool_call_invalid / schema_mismatch)
- R10 _execute_tool 捕获后回灌 conversation(结构化 dict 保留 error_code)
"""
from __future__ import annotations
import pytest
from agentkit.tools.base import Tool, ToolValidationError
class _StubTool(Tool):
    """Test double that records each ``execute`` call and returns a fixed payload.

    Args:
        schema: Forwarded to ``Tool`` as ``input_schema``; ``None`` means
            no schema is configured.
        payload: Value returned from ``execute``. Falls back to
            ``{"ok": True}`` only when omitted (``None``); an explicitly
            passed falsy payload such as ``{}`` is kept as-is.
    """

    def __init__(self, *, schema: dict | None = None, payload: dict | None = None):
        super().__init__(name="stub", description="stub", input_schema=schema)
        # kwargs of every execute() call, in order; tests assert on this to
        # prove whether validation let the call through.
        self.calls: list[dict] = []
        # `is None` instead of `or` so an intentionally empty payload is not
        # silently replaced by the default.
        self._payload = {"ok": True} if payload is None else payload

    async def execute(self, **kwargs) -> dict:
        """Record ``kwargs`` and return the configured payload."""
        self.calls.append(kwargs)
        return self._payload
_SCHEMA = {
"type": "object",
"properties": {"count": {"type": "integer"}},
"required": ["count"],
"additionalProperties": False,
}
# ---- R8 happy path ----
async def test_schema_valid_passes_through_to_execute():
    """Arguments that satisfy the schema reach ``execute`` unchanged."""
    stub = _StubTool(schema=_SCHEMA)
    outcome = await stub.safe_execute(count=5)
    expected_calls = [{"count": 5}]
    assert outcome == {"ok": True}
    assert stub.calls == expected_calls
# ---- R9 backward compat: input_schema=None skips ----
async def test_none_schema_skips_validation():
    """With ``input_schema=None`` arbitrary kwargs are passed straight to ``execute``."""
    tool = _StubTool(schema=None)
    result = await tool.safe_execute(anything="ok", even_invalid_types=123)
    # The call must reach execute() untouched and its result must be returned,
    # just as in the schema-valid happy path.
    assert result == {"ok": True}
    assert tool.calls == [{"anything": "ok", "even_invalid_types": 123}]
# ---- R9 type mismatch → tool_call_invalid ----
async def test_type_mismatch_raises_tool_call_invalid():
    """A wrongly typed field raises ``tool_call_invalid`` before ``execute`` runs."""
    stub = _StubTool(schema=_SCHEMA)
    with pytest.raises(ToolValidationError) as excinfo:
        await stub.safe_execute(count="abc")
    err = excinfo.value
    assert err.error_code == "tool_call_invalid"
    assert err.details["field"] == "count"
    assert not stub.calls  # execute was never invoked
# ---- R9 missing required → schema_mismatch ----
async def test_missing_required_raises_schema_mismatch():
    """Omitting a required field raises ``schema_mismatch`` from the ``required`` validator."""
    stub = _StubTool(schema=_SCHEMA)
    with pytest.raises(ToolValidationError) as excinfo:
        await stub.safe_execute()
    err = excinfo.value
    assert err.error_code == "schema_mismatch"
    assert err.details["validator"] == "required"
    assert not stub.calls  # execute was never invoked
# ---- R10 integration: _execute_tool catches and returns structured dict ----
async def test_execute_tool_catches_validation_error_returns_structured_dict():
    """ReActEngine._execute_tool should catch ToolValidationError first and keep its error_code."""
    from agentkit.core.react import ReActEngine

    stub = _StubTool(schema=_SCHEMA)
    # Bypass __init__ so no engine dependencies need to be constructed.
    engine = ReActEngine.__new__(ReActEngine)
    outcome = await engine._execute_tool("stub", {"count": "not-int"}, [stub])
    assert "error" in outcome
    assert outcome["error_code"] == "tool_call_invalid"
    assert "details" in outcome
    assert not stub.calls  # execute did not run
async def test_execute_tool_catches_missing_required_returns_schema_mismatch():
    """Missing required args come back from ``_execute_tool`` as a ``schema_mismatch`` dict."""
    from agentkit.core.react import ReActEngine

    tool = _StubTool(schema=_SCHEMA)
    engine = ReActEngine.__new__(ReActEngine)  # bypass __init__
    result = await engine._execute_tool("stub", {}, [tool])
    assert result["error_code"] == "schema_mismatch"
    assert result["details"]["validator"] == "required"
    # Like the other validation-failure tests: the tool body must not run.
    assert tool.calls == []
# ---- R10 self-check: structured error string survives into tool message ----
async def test_structured_error_dict_str_includes_error_code_for_llm_self_correction():
    """_build_tool_result_message stringifies the result dict into ``content``.

    The LLM therefore sees the error_code text in the tool message and can
    self-correct its next call.
    """
    from agentkit.core.react import ReActEngine

    stub = _StubTool(schema=_SCHEMA)
    engine = ReActEngine.__new__(ReActEngine)
    error_result = await engine._execute_tool("stub", {"count": "bad"}, [stub])
    message = await engine._build_tool_result_message(
        tool_call_id="t1", result=error_result
    )
    assert message["role"] == "tool"
    assert message["tool_call_id"] == "t1"
    assert "tool_call_invalid" in message["content"]