diff --git a/src/agentkit/core/react.py b/src/agentkit/core/react.py
index 836c5ee..31f0f87 100644
--- a/src/agentkit/core/react.py
+++ b/src/agentkit/core/react.py
@@ -153,9 +153,12 @@ class ReActEngine:
     # Default core tools that always get full descriptions injected into the
     # prompt. ``tool_search`` is included so its full description is always
     # available to the LLM when tiered injection is active.
+    # U1: replaced the broken `write_file` placeholder (no real implementation —
+    # only `_FakeTool` stubs) with `str_replace_editor` (workspace-root confined
+    # create/str_replace/insert_at_line/view — see tools/str_replace_editor.py).
     _DEFAULT_CORE_TOOLS: tuple[str, ...] = (
         "read_file",
-        "write_file",
+        "str_replace_editor",
         "bash",
         "search",
         "tool_search",
diff --git a/src/agentkit/tools/__init__.py b/src/agentkit/tools/__init__.py
index ea54dcf..b81018f 100644
--- a/src/agentkit/tools/__init__.py
+++ b/src/agentkit/tools/__init__.py
@@ -19,6 +19,7 @@ from agentkit.tools.web_search import WebSearchTool
 from agentkit.tools.builtin import RunTestsTool, ToolSearchTool
 from agentkit.tools.search import ToolSearchIndex
 from agentkit.tools.file_read import ReadFileTool
+from agentkit.tools.str_replace_editor import StrReplaceEditorTool
 from agentkit.tools.advance_phase import AdvancePhaseTool
 
 # Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
@@ -55,5 +56,6 @@ __all__ = [
     "ParsedOutput",
     "ErrorType",
     "ReadFileTool",
+    "StrReplaceEditorTool",
     "AdvancePhaseTool",
 ]
diff --git a/src/agentkit/tools/str_replace_editor.py b/src/agentkit/tools/str_replace_editor.py
new file mode 100644
index 0000000..0ff8210
--- /dev/null
+++ b/src/agentkit/tools/str_replace_editor.py
@@ -0,0 +1,462 @@
+"""StrReplaceEditorTool — structured file editing with workspace-root security (U1, R1).
+
+Replaces the broken `write_file` placeholder (which had no real implementation —
+only `_FakeTool` stubs in `cli/benchmark.py`). Provides four commands:
+
+  - `create`         write a new file (errors if it already exists — data-loss guard)
+  - `str_replace`    exact-match anchor replace (anchor must be unique in the file)
+  - `insert_at_line` insert text at a 1-based line number (0 = prepend, > EOF = append)
+  - `view`           read file with line numbers (needed so `str_replace` anchors
+                     and `insert_at_line` targets can be discovered)
+
+Security model (file-system analog of the 6-layer terminal security paradigm in
+`server/auth/terminal_security.py` — reject-by-default + prefix match):
+
+  1. Reject absolute paths (force relative interpretation against workspace root).
+  2. Reject any ``..`` path component (path traversal).
+  3. ``Path.resolve()`` follows symlinks, then ``relative_to(workspace_root)``
+     rejects symlink escape and any residual traversal.
+
+Filesystem I/O is wrapped in ``asyncio.to_thread`` to avoid blocking the event loop.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from pathlib import Path
+
+from agentkit.tools.base import Tool
+
+logger = logging.getLogger(__name__)
+
+
+class StrReplaceEditorTool(Tool):
+    """Structured file editor with four commands and workspace-root confinement.
+
+    Tool name ``str_replace_editor`` is registered in
+    ``core/react.py:_DEFAULT_CORE_TOOLS`` so its full description is always
+    injected into the LLM prompt (tiered description injection).
+    """
+
+    def __init__(
+        self,
+        workspace_root: str | Path | None = None,
+        name: str = "str_replace_editor",
+        description: str | None = None,
+        input_schema: dict[str, object] | None = None,
+        output_schema: dict[str, object] | None = None,
+        version: str = "1.0.0",
+        tags: list[str] | None = None,
+    ):
+        """Create the tool, confined to ``workspace_root``.
+
+        ``workspace_root`` defaults to the current working directory; the
+        remaining arguments override the tool metadata passed to ``Tool``.
+        """
+        # Resolve once so later prefix checks compare against a stable, real
+        # directory (no symlink in the workspace root itself).
+        self._workspace_root: Path = Path(workspace_root or Path.cwd()).resolve()
+        super().__init__(
+            name=name,
+            description=description or self._default_description(),
+            input_schema=input_schema or self._default_input_schema(),
+            output_schema=output_schema or self._default_output_schema(),
+            version=version,
+            tags=tags or ["io", "file", "edit"],
+        )
+
+    @staticmethod
+    def _default_description() -> str:
+        return (
+            "Edit a file with structured commands. Paths are relative to the "
+            "workspace root (absolute paths and `..` traversal are rejected; "
+            "symlink escape is blocked). Commands: `create` (write a new file — "
+            "errors if it exists), `str_replace` (replace a unique exact-match "
+            "anchor), `insert_at_line` (insert text at a 1-based line; 0=prepend, "
+            ">EOF=append), `view` (read file with line numbers). Always `view` a "
+            "file first to get exact anchors and line numbers."
+        )
+
+    @staticmethod
+    def _default_input_schema() -> dict[str, object]:
+        return {
+            "type": "object",
+            "properties": {
+                "command": {
+                    "type": "string",
+                    "enum": ["create", "str_replace", "insert_at_line", "view"],
+                    "description": "The editing command to execute.",
+                },
+                "path": {
+                    "type": "string",
+                    "description": (
+                        "Relative path to the file within the workspace root. "
+                        "Absolute paths and `..` components are rejected."
+                    ),
+                },
+                "file_text": {
+                    "type": "string",
+                    "description": "Required for `create`: full content of the new file.",
+                },
+                "old_str": {
+                    "type": "string",
+                    "description": (
+                        "Required for `str_replace`: exact text to find (whitespace "
+                        "and indentation must match). Must occur exactly once."
+                    ),
+                },
+                "new_str": {
+                    "type": "string",
+                    "description": (
+                        "Required for `str_replace` and `insert_at_line`: the "
+                        "replacement / insertion text (may be multi-line)."
+                    ),
+                },
+                "insert_line": {
+                    "type": "integer",
+                    "minimum": 0,
+                    "description": (
+                        "Required for `insert_at_line`: 1-based line number to insert "
+                        "BEFORE. 0 = prepend before line 1; greater than the file's "
+                        "line count = append at end."
+                    ),
+                },
+                "start_line": {
+                    "type": "integer",
+                    "minimum": 1,
+                    "description": "Optional for `view`: 1-based start line (inclusive).",
+                },
+                "end_line": {
+                    "type": "integer",
+                    "minimum": 1,
+                    "description": "Optional for `view`: 1-based end line (inclusive).",
+                },
+            },
+            "required": ["command", "path"],
+            "additionalProperties": False,
+        }
+
+    @staticmethod
+    def _default_output_schema() -> dict[str, object]:
+        return {
+            "type": "object",
+            "properties": {
+                "command": {"type": "string"},
+                "path": {"type": "string"},
+                "content": {"type": "string"},
+                "start_line": {"type": "integer"},
+                "end_line": {"type": "integer"},
+                "total_lines": {"type": "integer"},
+                "is_error": {"type": "boolean"},
+                "error": {"type": "string"},
+                "detail": {"type": "string"},
+                "note": {"type": "string"},
+            },
+        }
+
+    # ── path security ─────────────────────────────────────────────────
+
+    def _resolve_within_workspace(self, raw_path: str) -> Path | None:
+        """Resolve ``raw_path`` and verify it stays within the workspace root.
+
+        Returns the resolved absolute Path on success, or ``None`` if the path
+        is absolute, contains a ``..`` component, or resolves outside the
+        workspace root (path traversal or symlink escape). ``Path.resolve()``
+        follows symlinks, so a symlink pointing outside the workspace resolves
+        to an outside path and fails the ``relative_to`` check.
+        """
+        if not isinstance(raw_path, str) or not raw_path:
+            return None
+        p = Path(raw_path)
+        if p.is_absolute():
+            return None  # layer 1: force relative interpretation
+        if ".." in p.parts:
+            return None  # layer 2: reject path traversal
+        resolved = (self._workspace_root / raw_path).resolve()
+        try:
+            resolved.relative_to(self._workspace_root)  # layer 3: symlink escape
+        except ValueError:
+            return None
+        return resolved
+
+    # ── execute ────────────────────────────────────────────────────────
+
+    async def execute(self, **kwargs: object) -> dict[str, object]:
+        """Validate ``command`` and ``path``, then dispatch to the handler.
+
+        Validation failures are returned as ``is_error`` results rather than
+        raised, so the caller always gets a message it can act on.
+        """
+        command = kwargs.get("command")
+        raw_path = kwargs.get("path")
+        if command not in ("create", "str_replace", "insert_at_line", "view"):
+            return self._error(
+                f"Unknown command {command!r}; expected one of "
+                "create/str_replace/insert_at_line/view",
+                path=raw_path if isinstance(raw_path, str) else None,
+            )
+        if not isinstance(raw_path, str) or not raw_path:
+            return self._error("`path` is required and must be a non-empty string")
+
+        # resolve() touches the filesystem (symlink lookups), so keep it off
+        # the event loop like the rest of the I/O.
+        path = await asyncio.to_thread(self._resolve_within_workspace, raw_path)
+        if path is None:
+            return self._error(
+                f"Path {raw_path!r} is rejected: absolute paths, `..` traversal, "
+                f"and symlink escape outside the workspace root "
+                f"({self._workspace_root}) are not allowed.",
+                path=raw_path,
+            )
+
+        if command == "create":
+            return await self._cmd_create(path, kwargs)
+        if command == "str_replace":
+            return await self._cmd_str_replace(path, kwargs)
+        if command == "insert_at_line":
+            return await self._cmd_insert_at_line(path, kwargs)
+        return await self._cmd_view(path, kwargs)
+
+    # ── commands ───────────────────────────────────────────────────────
+
+    async def _cmd_create(self, path: Path, kwargs: dict[str, object]) -> dict[str, object]:
+        """Create a new file; refuse to overwrite an existing path."""
+        file_text = kwargs.get("file_text")
+        if not isinstance(file_text, str):
+            return self._error("`file_text` is required for `create`", path=str(path))
+        # Data-loss guard: the existence check and the write happen atomically
+        # in `_write_file` (exclusive-create mode), so a file that appears
+        # between "check" and "write" can never be clobbered.
+        return await asyncio.to_thread(
+            self._write_file, path, file_text, "create", exclusive=True
+        )
+
+    async def _cmd_str_replace(self, path: Path, kwargs: dict[str, object]) -> dict[str, object]:
+        """Replace the single exact occurrence of ``old_str`` with ``new_str``."""
+        old_str = kwargs.get("old_str")
+        new_str = kwargs.get("new_str")
+        if not isinstance(old_str, str) or old_str == "":
+            return self._error(
+                "`old_str` is required for `str_replace` and must be non-empty",
+                path=str(path),
+            )
+        if not isinstance(new_str, str):
+            return self._error("`new_str` is required for `str_replace`", path=str(path))
+
+        read_result = await asyncio.to_thread(self._read_file, path, strict=True)
+        if read_result["is_error"]:
+            return read_result
+        content = read_result["content"]
+        count = content.count(old_str)
+        if count == 0:
+            return self._error(
+                f"`old_str` anchor not found in {path}. Use `view` to inspect the "
+                f"exact text (whitespace and indentation must match).",
+                path=str(path),
+            )
+        if count > 1:
+            return self._error(
+                f"`old_str` anchor is not unique: found {count} matches in {path}. "
+                f"Include more surrounding context so the anchor matches once.",
+                path=str(path),
+            )
+        new_content = content.replace(old_str, new_str, 1)
+        return await asyncio.to_thread(self._write_file, path, new_content, "str_replace")
+
+    async def _cmd_insert_at_line(self, path: Path, kwargs: dict[str, object]) -> dict[str, object]:
+        """Insert ``new_str`` before 1-based ``insert_line`` (0 = prepend, > EOF = append)."""
+        new_str = kwargs.get("new_str")
+        if not isinstance(new_str, str):
+            return self._error("`new_str` is required for `insert_at_line`", path=str(path))
+        insert_line = kwargs.get("insert_line")
+        # bool is a subclass of int — exclude it explicitly.
+        if isinstance(insert_line, bool) or not isinstance(insert_line, int):
+            return self._error(
+                f"`insert_line` is required for `insert_at_line` and must be a "
+                f"non-negative integer, got {insert_line!r}",
+                path=str(path),
+            )
+        if insert_line < 0:
+            return self._error(f"`insert_line` must be >= 0, got {insert_line}", path=str(path))
+
+        read_result = await asyncio.to_thread(self._read_file, path, strict=True)
+        if read_result["is_error"]:
+            return read_result
+        content = read_result["content"]
+        lines = self._split_lines(content)
+        # 1-based line N → insert before it (index N-1). 0 → prepend (index 0).
+        # Beyond EOF → append (index len). The final newline is a terminator,
+        # not an extra line, so EOF here means the last logical line.
+        idx = 0 if insert_line == 0 else insert_line - 1
+        idx = max(0, min(idx, len(lines)))
+        result_lines = lines[:idx] + self._split_lines(new_str) + lines[idx:]
+        new_content = "\n".join(result_lines)
+        # `_split_lines` consumed exactly one terminating newline; put exactly
+        # one back so blank lines at EOF survive the round trip unchanged.
+        if content.endswith("\n"):
+            new_content += "\n"
+        return await asyncio.to_thread(self._write_file, path, new_content, "insert_at_line")
+
+    async def _cmd_view(self, path: Path, kwargs: dict[str, object]) -> dict[str, object]:
+        """Return the file (or an inclusive line range of it) with ``cat -n`` numbering."""
+        start_line = kwargs.get("start_line")
+        end_line = kwargs.get("end_line")
+        if start_line is not None and (
+            not isinstance(start_line, int) or isinstance(start_line, bool) or start_line < 1
+        ):
+            return self._error(
+                f"`start_line` must be a positive integer, got {start_line!r}",
+                path=str(path),
+            )
+        if end_line is not None and (
+            not isinstance(end_line, int) or isinstance(end_line, bool) or end_line < 1
+        ):
+            return self._error(
+                f"`end_line` must be a positive integer, got {end_line!r}",
+                path=str(path),
+            )
+        if start_line is not None and end_line is not None and end_line < start_line:
+            return self._error(
+                f"`end_line` ({end_line}) must be >= `start_line` ({start_line})",
+                path=str(path),
+            )
+
+        read_result = await asyncio.to_thread(self._read_file, path)
+        if read_result["is_error"]:
+            return read_result
+        content = read_result["content"]
+        lines = self._split_lines(content)
+        total = len(lines)
+        if total == 0:
+            return {
+                "command": "view",
+                "path": str(path),
+                "content": "",
+                "start_line": 0,
+                "end_line": 0,
+                "total_lines": 0,
+                "is_error": False,
+                "note": "empty file",
+            }
+        s = max(1, start_line or 1)
+        e = min(total, end_line or total)
+        if s > total:
+            numbered = ""
+            note = f"range starts beyond EOF (file has {total} lines)"
+        else:
+            sliced = lines[s - 1 : e]
+            # cat -n style: right-aligned 1-based number + tab. ASCII only.
+            numbered = "\n".join(f"{i:>6}\t{line}" for i, line in enumerate(sliced, start=s))
+            note = None
+        result: dict[str, object] = {
+            "command": "view",
+            "path": str(path),
+            "content": numbered,
+            "start_line": s,
+            "end_line": e,
+            "total_lines": total,
+            "is_error": False,
+        }
+        if note:
+            result["note"] = note
+        return result
+
+    # ── blocking filesystem helpers (run via to_thread) ────────────────
+
+    def _read_file(self, path: Path, *, strict: bool = False) -> dict[str, object]:
+        """Read ``path`` as UTF-8 text and return it under ``content``.
+
+        With ``strict=False`` (used by `view`) undecodable bytes are shown as
+        U+FFFD. Editing commands pass ``strict=True``: writing a lossily
+        decoded buffer back would silently destroy the original bytes, so a
+        file that is not valid UTF-8 is reported as an error instead.
+        """
+        if not path.exists():
+            return self._error(f"File not found: {path}", path=str(path))
+        if path.is_dir():
+            return self._error(f"Path is a directory, not a file: {path}", path=str(path))
+        try:
+            content = path.read_text(
+                encoding="utf-8", errors="strict" if strict else "replace"
+            )
+        except UnicodeDecodeError as e:
+            return self._error(
+                f"File is not valid UTF-8; refusing to edit it: {path}",
+                path=str(path),
+                detail=str(e),
+            )
+        except PermissionError as e:
+            return self._error(f"Permission denied: {path}", path=str(path), detail=str(e))
+        except OSError as e:
+            return self._error(f"Failed to read {path}: {e}", path=str(path))
+        return {"content": content, "is_error": False}
+
+    def _write_file(
+        self, path: Path, content: str, command: str, *, exclusive: bool = False
+    ) -> dict[str, object]:
+        """Write ``content`` to ``path`` as UTF-8, creating parent directories.
+
+        With ``exclusive=True`` the file is opened in ``"x"`` mode, so the
+        write fails (instead of overwriting) when the path already exists.
+        """
+        try:
+            path.parent.mkdir(parents=True, exist_ok=True)
+            try:
+                with path.open("x" if exclusive else "w", encoding="utf-8") as fh:
+                    fh.write(content)
+            except FileExistsError:
+                return self._error(
+                    f"File already exists (create refuses to overwrite): {path}. "
+                    f"Use str_replace to edit it.",
+                    path=str(path),
+                )
+        except PermissionError as e:
+            return self._error(f"Permission denied: {path}", path=str(path), detail=str(e))
+        except OSError as e:
+            return self._error(f"Failed to write {path}: {e}", path=str(path))
+        # ponytail: line counting is O(n) per write; fine for editor-scale files
+        # (<1 MB). For VFS-scale writes, pass len(lines) from the caller instead.
+        return {
+            "command": command,
+            "path": str(path),
+            "content": content,
+            "total_lines": len(self._split_lines(content)),
+            "is_error": False,
+            "note": f"{command} succeeded",
+        }
+
+    # ── text helpers ───────────────────────────────────────────────────
+
+    @staticmethod
+    def _split_lines(text: str) -> list[str]:
+        """Split ``text`` into lines on newline characters only.
+
+        Unlike ``str.splitlines()``, form feed, vertical tab, ``\\x1c``-``\\x1e``,
+        ``\\x85`` and U+2028/U+2029 are NOT line boundaries here: they are
+        ordinary characters inside a line, and splitting on them would shift
+        line numbers and rewrite them as ``\\n`` on the next save. ``\\r\\n``
+        and ``\\r`` count as newlines. One final newline terminates the last
+        line instead of starting an empty one.
+        """
+        if not text:
+            return []
+        lines = text.replace("\r\n", "\n").replace("\r", "\n").split("\n")
+        if lines[-1] == "":
+            lines.pop()
+        return lines
+
+    @staticmethod
+    def _error(
+        message: str,
+        *,
+        path: str | None = None,
+        detail: str | None = None,
+    ) -> dict[str, object]:
+        """Build an ``is_error`` result, adding ``path``/``detail`` only when given."""
+        result: dict[str, object] = {"is_error": True, "error": message}
+        if path is not None:
+            result["path"] = path
+        if detail is not None:
+            result["detail"] = detail
+        return result
diff --git a/tests/unit/test_str_replace_editor.py b/tests/unit/test_str_replace_editor.py
new file mode 100644
index 0000000..16c7a79
--- /dev/null
+++ b/tests/unit/test_str_replace_editor.py
@@ -0,0 +1,458 @@
+"""Unit tests for StrReplaceEditorTool (U1, R1).
+
+Covers happy path, edge cases, error/failure paths, path-security rejection,
+and the integration contract that the tool is registered as a default core
+tool in ReActEngine and exported from the tools package.
+"""
+
+from __future__ import annotations
+
+import os
+import sys
+from pathlib import Path
+
+import pytest
+
+from agentkit.tools.str_replace_editor import StrReplaceEditorTool
+
+
+# ── fixtures ──────────────────────────────────────────────────────────
+
+
+@pytest.fixture
+def workspace(tmp_path: Path) -> Path:
+    """A clean workspace root directory for each test."""
+    return tmp_path
+
+
+@pytest.fixture
+def tool(workspace: Path) -> StrReplaceEditorTool:
+    return StrReplaceEditorTool(workspace_root=workspace)
+
+
+# ── happy path ────────────────────────────────────────────────────────
+
+
+async def test_create_writes_new_file(tool: StrReplaceEditorTool, workspace: Path) -> None:
+    result = await tool.execute(command="create", path="hello.py", file_text="print('hi')\n")
+    assert result["is_error"] is False
+    assert result["command"] == "create"
+    assert result["total_lines"] == 1
+    assert (workspace / "hello.py").read_text() == "print('hi')\n"
+
+
+async def test_view_returns_content_with_line_numbers(
+    tool: StrReplaceEditorTool, workspace: Path
+) -> None:
+    (workspace / "a.txt").write_text("alpha\nbeta\ngamma\n")
+    result = await tool.execute(command="view", path="a.txt")
+    assert result["is_error"] is False
+    assert result["total_lines"] == 3
+    assert result["start_line"] == 1
+    assert result["end_line"] == 3
+    # cat -n style: right-aligned number + tab.
+    assert result["content"] == "     1\talpha\n     2\tbeta\n     3\tgamma"
+
+
+async def test_str_replace_replaces_unique_anchor(
+    tool: StrReplaceEditorTool, workspace: Path
+) -> None:
+    (workspace / "f.txt").write_text("def foo():\n    return 1\n")
+    result = await tool.execute(
+        command="str_replace",
+        path="f.txt",
+        old_str="return 1",
+        new_str="return 2",
+    )
+    assert result["is_error"] is False
+    assert (workspace / "f.txt").read_text() == "def foo():\n    return 2\n"
+
+
+async def test_insert_at_line_inserts_in_middle(
+    tool: StrReplaceEditorTool, workspace: Path
+) -> None:
+    (workspace / "f.txt").write_text("line1\nline2\nline3\n")
+    result = await tool.execute(
+        command="insert_at_line", path="f.txt", insert_line=2, new_str="INSERTED"
+    )
+    assert result["is_error"] is False
+    assert (workspace / "f.txt").read_text() == "line1\nINSERTED\nline2\nline3\n"
+
+
+# ── edge cases ────────────────────────────────────────────────────────
+
+
+async def test_create_empty_file(tool: StrReplaceEditorTool, workspace: Path) -> None:
+    result = await tool.execute(command="create", path="empty.txt", file_text="")
+    assert result["is_error"] is False
+    assert result["total_lines"] == 0
+    assert (workspace / "empty.txt").read_text() == ""
+    # view of an empty file reports total_lines=0 with a note.
+    view = await tool.execute(command="view", path="empty.txt")
+    assert view["is_error"] is False
+    assert view["total_lines"] == 0
+    assert view["content"] == ""
+    assert view["note"] == "empty file"
+
+
+async def test_str_replace_multiple_matches_is_error(
+    tool: StrReplaceEditorTool, workspace: Path
+) -> None:
+    (workspace / "f.txt").write_text("x\nx\n")
+    result = await tool.execute(command="str_replace", path="f.txt", old_str="x", new_str="y")
+    assert result["is_error"] is True
+    assert "not unique" in result["error"]
+    # File is untouched on error.
+    assert (workspace / "f.txt").read_text() == "x\nx\n"
+
+
+async def test_insert_at_line_zero_prepends(tool: StrReplaceEditorTool, workspace: Path) -> None:
+    (workspace / "f.txt").write_text("line1\nline2\n")
+    result = await tool.execute(
+        command="insert_at_line", path="f.txt", insert_line=0, new_str="TOP"
+    )
+    assert result["is_error"] is False
+    assert (workspace / "f.txt").read_text() == "TOP\nline1\nline2\n"
+
+
+async def test_insert_at_line_beyond_eof_appends(
+    tool: StrReplaceEditorTool, workspace: Path
+) -> None:
+    (workspace / "f.txt").write_text("line1\nline2\n")
+    result = await tool.execute(
+        command="insert_at_line", path="f.txt", insert_line=99, new_str="BOTTOM"
+    )
+    assert result["is_error"] is False
+    assert (workspace / "f.txt").read_text() == "line1\nline2\nBOTTOM\n"
+
+
+async def test_insert_at_line_multiline_text(tool: StrReplaceEditorTool, workspace: Path) -> None:
+    (workspace / "f.txt").write_text("a\nb\n")
+    result = await tool.execute(
+        command="insert_at_line",
+        path="f.txt",
+        insert_line=2,
+        new_str="x\ny\nz",
+    )
+    assert result["is_error"] is False
+    assert (workspace / "f.txt").read_text() == "a\nx\ny\nz\nb\n"
+
+
+async def test_insert_at_line_preserves_non_newline_separators(
+    tool: StrReplaceEditorTool, workspace: Path
+) -> None:
+    # Form feed is an ordinary character, not a line boundary.
+    (workspace / "f.txt").write_text("a\x0cb\nc\n")
+    result = await tool.execute(
+        command="insert_at_line", path="f.txt", insert_line=2, new_str="X"
+    )
+    assert result["is_error"] is False
+    assert (workspace / "f.txt").read_text() == "a\x0cb\nX\nc\n"
+
+
+async def test_insert_at_line_keeps_blank_lines_at_eof(
+    tool: StrReplaceEditorTool, workspace: Path
+) -> None:
+    (workspace / "f.txt").write_text("a\n\n")
+    result = await tool.execute(
+        command="insert_at_line", path="f.txt", insert_line=0, new_str="TOP"
+    )
+    assert result["is_error"] is False
+    assert (workspace / "f.txt").read_text() == "TOP\na\n\n"
+
+
+async def test_view_with_line_range(tool: StrReplaceEditorTool, workspace: Path) -> None:
+    (workspace / "f.txt").write_text("one\ntwo\nthree\nfour\nfive\n")
+    result = await tool.execute(command="view", path="f.txt", start_line=2, end_line=4)
+    assert result["is_error"] is False
+    assert result["start_line"] == 2
+    assert result["end_line"] == 4
+    assert result["total_lines"] == 5
+    assert result["content"] == "     2\ttwo\n     3\tthree\n     4\tfour"
+
+
+async def test_view_range_beyond_eof_returns_empty(
+    tool: StrReplaceEditorTool, workspace: Path
+) -> None:
+    (workspace / "f.txt").write_text("only\n")
+    result = await tool.execute(command="view", path="f.txt", start_line=10, end_line=20)
+    assert result["is_error"] is False
+    assert result["content"] == ""
+    assert result["start_line"] == 10
+
+
+# ── error and failure paths ───────────────────────────────────────────
+
+
+async def test_create_refuses_overwrite(tool: StrReplaceEditorTool, workspace: Path) -> None:
+    (workspace / "f.txt").write_text("existing\n")
+    result = await tool.execute(command="create", path="f.txt", file_text="new\n")
+    assert result["is_error"] is True
+    assert "already exists" in result["error"]
+    # Original content preserved (data-loss guard).
+    assert (workspace / "f.txt").read_text() == "existing\n"
+
+
+async def test_str_replace_anchor_not_found(tool: StrReplaceEditorTool, workspace: Path) -> None:
+    (workspace / "f.txt").write_text("hello world\n")
+    result = await tool.execute(
+        command="str_replace", path="f.txt", old_str="goodbye", new_str="hi"
+    )
+    assert result["is_error"] is True
+    assert "not found" in result["error"]
+
+
+async def test_str_replace_empty_old_str_rejected(
+    tool: StrReplaceEditorTool, workspace: Path
+) -> None:
+    (workspace / "f.txt").write_text("x\n")
+    result = await tool.execute(command="str_replace", path="f.txt", old_str="", new_str="y")
+    assert result["is_error"] is True
+    assert "old_str" in result["error"]
+
+
+async def test_str_replace_on_missing_file(tool: StrReplaceEditorTool, workspace: Path) -> None:
+    result = await tool.execute(command="str_replace", path="nope.txt", old_str="a", new_str="b")
+    assert result["is_error"] is True
+    assert "not found" in result["error"].lower()
+
+
+async def test_str_replace_refuses_non_utf8_file(
+    tool: StrReplaceEditorTool, workspace: Path
+) -> None:
+    # Editing a lossily decoded buffer would destroy the original bytes.
+    raw = b"caf\xe9 x\n"
+    (workspace / "latin1.txt").write_bytes(raw)
+    result = await tool.execute(
+        command="str_replace", path="latin1.txt", old_str="x", new_str="y"
+    )
+    assert result["is_error"] is True
+    assert "UTF-8" in result["error"]
+    assert (workspace / "latin1.txt").read_bytes() == raw
+
+
+async def test_path_traversal_rejected(tool: StrReplaceEditorTool, workspace: Path) -> None:
+    result = await tool.execute(command="view", path="../../etc/passwd")
+    assert result["is_error"] is True
+    assert "rejected" in result["error"]
+
+
+async def test_path_traversal_create_rejected(
+    tool: StrReplaceEditorTool, workspace: Path, tmp_path: Path
+) -> None:
+    # Even if the target would resolve inside a sibling dir, `..` is rejected.
+    result = await tool.execute(command="create", path="../sibling.txt", file_text="x")
+    assert result["is_error"] is True
+
+
+async def test_absolute_path_rejected(tool: StrReplaceEditorTool, workspace: Path) -> None:
+    # Absolute path to a real file outside the workspace.
+    result = await tool.execute(command="view", path="/etc/passwd")
+    assert result["is_error"] is True
+    assert "rejected" in result["error"]
+
+
+async def test_absolute_path_inside_workspace_also_rejected(
+    tool: StrReplaceEditorTool, workspace: Path
+) -> None:
+    # Absolute paths are rejected outright (force relative interpretation),
+    # even when the path would resolve inside the workspace.
+    target = workspace / "inside.txt"
+    target.write_text("ok\n")
+    result = await tool.execute(command="view", path=str(target))
+    assert result["is_error"] is True
+    assert "rejected" in result["error"]
+
+
+async def test_symlink_escape_rejected(tmp_path: Path) -> None:
+    # Use a workspace SUBDIR of tmp_path so a file under tmp_path (but not
+    # under the workspace) counts as "outside the workspace".
+    workspace = tmp_path / "ws"
+    workspace.mkdir()
+    tool = StrReplaceEditorTool(workspace_root=workspace)
+    # Real secret file OUTSIDE the workspace (sibling, still under tmp_path).
+    outside = tmp_path / "secret.txt"
+    outside.write_text("top secret\n")
+    # Symlink inside the workspace pointing to the outside file.
+    link = workspace / "escape.txt"
+    os.symlink(outside, link)
+    # view through the symlink must be rejected (symlink escape).
+    result = await tool.execute(command="view", path="escape.txt")
+    assert result["is_error"] is True
+    assert "rejected" in result["error"]
+    # create through a symlink that escapes must also be rejected.
+    result2 = await tool.execute(
+        command="create", path="escape.txt", file_text="overwrite\n"
+    )
+    assert result2["is_error"] is True
+    # The outside file must NOT have been overwritten (data-loss guard).
+    assert outside.read_text() == "top secret\n"
+
+
+async def test_symlink_to_inside_workspace_allowed(
+    tool: StrReplaceEditorTool, workspace: Path
+) -> None:
+    # A symlink whose target is INSIDE the workspace is allowed (no escape).
+    real = workspace / "real.txt"
+    real.write_text("content\n")
+    link = workspace / "link.txt"
+    os.symlink(real, link)
+    result = await tool.execute(command="view", path="link.txt")
+    assert result["is_error"] is False
+    assert "content" in result["content"]
+
+
+async def test_file_outside_workspace_rejected(tool: StrReplaceEditorTool, tmp_path: Path) -> None:
+    # A relative path that climbs out via `..` is rejected by the `..` rule,
+    # but also verify a nested traversal attempt is caught.
+    result = await tool.execute(command="view", path="sub/../../etc/passwd")
+    assert result["is_error"] is True
+
+
+async def test_unknown_command_rejected(tool: StrReplaceEditorTool, workspace: Path) -> None:
+    result = await tool.execute(command="delete", path="f.txt")
+    assert result["is_error"] is True
+    assert "Unknown command" in result["error"]
+
+
+async def test_missing_path_rejected(tool: StrReplaceEditorTool) -> None:
+    result = await tool.execute(command="view", path="")
+    assert result["is_error"] is True
+    assert "path" in result["error"].lower()
+
+
+async def test_missing_file_text_rejected(tool: StrReplaceEditorTool, workspace: Path) -> None:
+    result = await tool.execute(command="create", path="f.txt")
+    assert result["is_error"] is True
+    assert "file_text" in result["error"]
+
+
+async def test_missing_insert_line_rejected(tool: StrReplaceEditorTool, workspace: Path) -> None:
+    (workspace / "f.txt").write_text("a\n")
+    result = await tool.execute(command="insert_at_line", path="f.txt", new_str="b")
+    assert result["is_error"] is True
+    assert "insert_line" in result["error"]
+
+
+async def test_insert_line_negative_rejected(tool: StrReplaceEditorTool, workspace: Path) -> None:
+    (workspace / "f.txt").write_text("a\n")
+    result = await tool.execute(command="insert_at_line", path="f.txt", insert_line=-1, new_str="b")
+    assert result["is_error"] is True
+
+
+async def test_view_directory_rejected(tool: StrReplaceEditorTool, workspace: Path) -> None:
+    (workspace / "subdir").mkdir()
+    result = await tool.execute(command="view", path="subdir")
+    assert result["is_error"] is True
+    assert "directory" in result["error"].lower()
+
+
+async def test_create_in_nested_subdir_creates_parents(
+    tool: StrReplaceEditorTool, workspace: Path
+) -> None:
+    result = await tool.execute(
+        command="create",
+        path="nested/deep/file.txt",
+        file_text="deep\n",
+    )
+    assert result["is_error"] is False
+    assert (workspace / "nested" / "deep" / "file.txt").read_text() == "deep\n"
+
+
+# ── integration contract ──────────────────────────────────────────────
+
+
+def test_str_replace_editor_in_default_core_tools() -> None:
+    """The tool must be a default core tool so its full description is
+    always injected into the LLM prompt (tiered injection)."""
+    from agentkit.core.react import ReActEngine
+
+    assert "str_replace_editor" in ReActEngine._DEFAULT_CORE_TOOLS
+    # The broken write_file placeholder must be gone.
+    assert "write_file" not in ReActEngine._DEFAULT_CORE_TOOLS
+
+
+def test_tool_exported_from_tools_package() -> None:
+    from agentkit.tools import StrReplaceEditorTool as Exported
+
+    assert Exported is StrReplaceEditorTool
+
+
+def test_tool_name_and_schema(tool: StrReplaceEditorTool) -> None:
+    assert tool.name == "str_replace_editor"
+    assert tool.input_schema is not None
+    props = tool.input_schema["properties"]
+    assert "command" in props
+    assert set(props["command"]["enum"]) == {
+        "create",
+        "str_replace",
+        "insert_at_line",
+        "view",
+    }
+    # Description mentions all four commands so the LLM knows what it can do.
+    assert "create" in tool.description
+    assert "str_replace" in tool.description
+    assert "insert_at_line" in tool.description
+    assert "view" in tool.description
+
+
+def test_tool_appears_in_prompt_when_registered() -> None:
+    """When a StrReplaceEditorTool is in the tool list and is a default core
+    tool, its full description (name + parameters) must appear in the
+    ReActEngine tool-use prompt (tiered injection contract)."""
+    from unittest.mock import MagicMock
+
+    from agentkit.core.react import ReActEngine
+
+    engine = ReActEngine(llm_gateway=MagicMock(), max_steps=1)
+    prompt = engine._build_tool_use_prompt([StrReplaceEditorTool()])
+    # Full description injected (core tool).
+    assert "str_replace_editor" in prompt
+    assert "create" in prompt and "str_replace" in prompt
+    assert "insert_at_line" in prompt and "view" in prompt
+
+
+# ── end-to-end workflow ───────────────────────────────────────────────
+
+
+async def test_create_view_str_replace_insert_workflow(
+    tool: StrReplaceEditorTool, workspace: Path
+) -> None:
+    # 1. create
+    created = await tool.execute(
+        command="create",
+        path="app.py",
+        file_text="def main():\n    pass\n",
+    )
+    assert created["is_error"] is False
+
+    # 2. view (get exact anchors / line numbers)
+    viewed = await tool.execute(command="view", path="app.py")
+    assert viewed["is_error"] is False
+    assert "     1\tdef main():" in viewed["content"]
+
+    # 3. str_replace
+    replaced = await tool.execute(
+        command="str_replace",
+        path="app.py",
+        old_str="    pass",
+        new_str="    return 42",
+    )
+    assert replaced["is_error"] is False
+
+    # 4. insert_at_line (add a docstring at the top)
+    inserted = await tool.execute(
+        command="insert_at_line",
+        path="app.py",
+        insert_line=0,
+        new_str='"""Module doc."""',
+    )
+    assert inserted["is_error"] is False
+
+    final = (workspace / "app.py").read_text()
+    assert final == '"""Module doc."""\ndef main():\n    return 42\n'
+
+
+if __name__ == "__main__":
+    # Allow direct execution for a quick smoke check without pytest.
+    sys.exit(pytest.main([__file__, "-x", "-q"]))
diff --git a/tests/unit/tools/test_tool_search.py b/tests/unit/tools/test_tool_search.py
index 43b65ec..077c651 100644
--- a/tests/unit/tools/test_tool_search.py
+++ b/tests/unit/tools/test_tool_search.py
@@ -382,7 +382,7 @@ class TestReActTieredInjection:
         engine = self._make_engine()
         tools = [
             FakeTool(name="read_file", description="Read a file."),
-            FakeTool(name="write_file", description="Write a file."),
+            FakeTool(name="str_replace_editor", description="Edit a file."),
         ]
         result = engine._maybe_add_tool_search(tools)
         assert len(result) == 2