# src/agentkit/tools/__init__.py — the same change also exports the new tool:
#     from agentkit.tools.file_read import ReadFileTool
#     and lists "ReadFileTool" in __all__.
#
# src/agentkit/tools/file_read.py
"""ReadFileTool — file reading with optional symbol-level sharding (G5, R22/R23).

Backward compatible with the pre-existing `_FakeTool` benchmark shape — when
`symbol=None`, returns the full file content. When `symbol="foo"`, returns
the line range of the first matching symbol via `SymbolExtractor`.

KTD2 (Wave 3 plan): dedicated tool, does NOT extend ShellTool — keeps the
file-reading contract clean and gives the LLM a focused schema.
"""

import logging
from pathlib import Path
from typing import Any

from agentkit.tools.base import Tool
from agentkit.tools.symbol_extractor import (
    SymbolSpan,
    get_extractor,
    language_for_extension,
)

logger = logging.getLogger(__name__)


class ReadFileTool(Tool):
    """Read a file from the filesystem, optionally sliced to a single symbol.

    Tool name `read_file` matches the reserved entry in
    `core/react.py:_DEFAULT_CORE_TOOLS` (which previously had no real
    implementation — only `_FakeTool` stubs in `cli/benchmark.py`).

    Backward-compat contract: `symbol=None` returns the full file content,
    matching the shape `{"path": ...}` that downstream callers (benchmark,
    phase whitelist) already expect.
    """

    def __init__(
        self,
        name: str = "read_file",
        description: str | None = None,
        input_schema: dict[str, Any] | None = None,
        output_schema: dict[str, Any] | None = None,
        version: str = "1.0.0",
        tags: list[str] | None = None,
    ):
        """Register the tool; every argument falls back to a built-in default."""
        super().__init__(
            name=name,
            description=description
            or (
                "Read a file from the filesystem. By default returns the full file "
                "content. Pass `symbol` (function/class/struct name) to slice to just "
                "that symbol's line range — saves context when you only need one "
                "function from a large file. Pass `start_line`/`end_line` for manual "
                "slicing. If `symbol` is set but not found, returns the available "
                "symbol names so you can retry."
            ),
            input_schema=input_schema or self._default_input_schema(),
            output_schema=output_schema or self._default_output_schema(),
            version=version,
            tags=tags or ["io", "file", "read"],
        )

    @staticmethod
    def _default_input_schema() -> dict[str, Any]:
        """Return the JSON schema for tool arguments (only `path` is required)."""
        return {
            "type": "object",
            "properties": {
                "path": {
                    "type": "string",
                    "description": "Path to the file to read (absolute or relative to cwd).",
                },
                "symbol": {
                    "type": "string",
                    "description": (
                        "Optional: name of a function/class/struct/method to slice to. "
                        "When set, returns only the line range of the first matching "
                        "symbol. Supported languages: py, ts/tsx, js/jsx, go, rs, java."
                    ),
                },
                "start_line": {
                    "type": "integer",
                    "description": "Optional 1-based start line for manual slicing. Overrides `symbol`.",
                    "minimum": 1,
                },
                "end_line": {
                    "type": "integer",
                    "description": "Optional 1-based end line (inclusive) for manual slicing. Overrides `symbol`.",
                    "minimum": 1,
                },
            },
            "required": ["path"],
            "additionalProperties": False,
        }

    @staticmethod
    def _default_output_schema() -> dict[str, Any]:
        """Return the JSON schema describing every key `execute` can emit."""
        return {
            "type": "object",
            "properties": {
                "content": {"type": "string"},
                "path": {"type": "string"},
                "start_line": {"type": "integer"},
                "end_line": {"type": "integer"},
                # The three keys below were emitted by `execute` but missing
                # from the schema; declared so the schema matches the output.
                "total_lines": {"type": "integer"},
                "symbol": {"type": "string"},
                "symbol_kind": {"type": "string"},
                "available_symbols": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Populated when `symbol` is set but not found.",
                },
                "note": {"type": "string"},
                "is_error": {"type": "boolean"},
                "error": {"type": "string"},
                "detail": {"type": "string"},
            },
        }

    async def execute(self, **kwargs) -> dict[str, Any]:
        """Read `path` and return the whole file, a line range, or one symbol.

        Keyword Args:
            path: File to read; relative paths are resolved against the cwd.
            symbol: Optional symbol name to slice to.
            start_line: Optional 1-based first line; takes precedence over `symbol`.
            end_line: Optional 1-based inclusive last line; takes precedence
                over `symbol`.

        Returns:
            A result dict. Failures never raise: they come back with
            `is_error=True` and an `error` message. A symbol that cannot be
            found is a soft miss (`is_error=False`, empty `content`, and
            `available_symbols` listing what was found).
        """
        raw_path = kwargs.get("path")
        if not raw_path:
            return self._error("`path` is required")
        try:
            path = Path(raw_path)
        except TypeError:
            # Non path-like values (e.g. an int) would otherwise escape as an
            # exception instead of the tool's error-dict contract.
            return self._error(f"`path` must be a string, got {type(raw_path).__name__}")
        if not path.is_absolute():
            path = path.resolve()

        symbol = kwargs.get("symbol")
        start_line = kwargs.get("start_line")
        end_line = kwargs.get("end_line")

        # Validate/sanitize line overrides.
        for label, value in (("start_line", start_line), ("end_line", end_line)):
            if value is not None and not _is_positive_int(value):
                return self._error(f"`{label}` must be a positive integer, got {value!r}")
        if start_line is not None and end_line is not None and end_line < start_line:
            return self._error(f"`end_line` ({end_line}) must be >= `start_line` ({start_line})")

        # Filesystem checks.
        if not path.exists():
            return self._error(f"File not found: {path}", path=str(path))
        if path.is_dir():
            return self._error(f"Path is a directory, not a file: {path}", path=str(path))

        try:
            content = path.read_text(encoding="utf-8", errors="replace")
        except PermissionError as e:
            return self._error(f"Permission denied: {path}", path=str(path), detail=str(e))
        except OSError as e:
            return self._error(f"Failed to read {path}: {e}", path=str(path))

        lines = content.splitlines()

        # Manual slicing takes precedence over symbol (per plan U1 Approach).
        if start_line is not None or end_line is not None:
            return self._read_line_range(path, lines, start_line, end_line)
        if symbol:
            return self._read_symbol(path, content, lines, symbol)

        # Default: full file (characterization baseline — matches _FakeTool shape).
        return {
            "content": content,
            "path": str(path),
            "start_line": 1,
            "end_line": len(lines),
            "total_lines": len(lines),
            "is_error": False,
        }

    @staticmethod
    def _read_line_range(
        path: Path,
        lines: list[str],
        start_line: int | None,
        end_line: int | None,
    ) -> dict[str, Any]:
        """Return the inclusive 1-based range, clamped to the file length."""
        total_lines = len(lines)
        first = start_line or 1
        last = min(total_lines, end_line or total_lines)
        result: dict[str, Any] = {
            "content": "\n".join(lines[first - 1 : last]),
            "path": str(path),
            "start_line": first,
            "end_line": last,
            "total_lines": total_lines,
            "is_error": False,
        }
        if first > total_lines:
            # The slice is empty and `end_line` < `start_line`; explain why
            # rather than handing back an unexplained empty result.
            result["note"] = (
                f"`start_line` ({first}) is past the end of the file "
                f"({total_lines} lines); nothing to return"
            )
        return result

    @staticmethod
    def _read_symbol(
        path: Path, content: str, lines: list[str], symbol: str
    ) -> dict[str, Any]:
        """Return the slice for `symbol`, or a soft miss listing known symbols."""
        total_lines = len(lines)
        ext = path.suffix.lower()
        language = language_for_extension(ext)
        if not language:
            # Unsupported extension: return full file with note (per plan U1 Edge case).
            return {
                "content": content,
                "path": str(path),
                "start_line": 1,
                "end_line": total_lines,
                "total_lines": total_lines,
                "note": f"symbol extraction not supported for {ext or 'unknown extension'}",
                "is_error": False,
            }

        # Extract from the text already in memory: no second disk read, and
        # the spans are guaranteed to describe the same bytes as `lines`.
        extractor = get_extractor(language)
        spans = extractor.extract_symbols(content, language) if extractor is not None else []

        match = _find_symbol(spans, symbol)
        if match is None:
            available = sorted({span.name for span in spans})
            return {
                "content": "",
                "path": str(path),
                "symbol": symbol,
                "available_symbols": available,
                "is_error": False,
                "note": (
                    f"Symbol {symbol!r} not found in {path.name}. "
                    f"Available: {', '.join(available) if available else '(none)'}"
                ),
            }

        first = match.start_line
        last = min(match.end_line, total_lines)
        return {
            "content": "\n".join(lines[first - 1 : last]),
            "path": str(path),
            "symbol": symbol,
            "symbol_kind": match.kind,
            "start_line": first,
            "end_line": last,
            "total_lines": total_lines,
            "is_error": False,
        }

    @staticmethod
    def _error(
        message: str, *, path: str | None = None, detail: str | None = None
    ) -> dict[str, Any]:
        """Build the error-shaped result; `path`/`detail` are added only when given."""
        result: dict[str, Any] = {
            "content": "",
            "is_error": True,
            "error": message,
        }
        if path is not None:
            result["path"] = path
        if detail is not None:
            result["detail"] = detail
        return result


def _is_positive_int(value: Any) -> bool:
    """Return True for ints >= 1; `bool` is rejected although it subclasses `int`."""
    return isinstance(value, int) and not isinstance(value, bool) and value >= 1


def _find_symbol(spans: list[SymbolSpan], name: str) -> SymbolSpan | None:
    """Find the symbol matching `name` that starts earliest in the file.

    Case-sensitive. The earliest `start_line` is chosen explicitly instead of
    trusting the order of `spans`, so "first match wins" holds even when an
    extractor does not return its spans in source order.

    ponytail: linear scan is fine for typical file symbol counts (<100).
    """
    matches = [span for span in spans if span.name == name]
    return min(matches, key=lambda span: span.start_line, default=None)
+ """ + for span in spans: + if span.name == name: + return span + return None diff --git a/src/agentkit/tools/symbol_extractor.py b/src/agentkit/tools/symbol_extractor.py new file mode 100644 index 0000000..93f64d3 --- /dev/null +++ b/src/agentkit/tools/symbol_extractor.py @@ -0,0 +1,278 @@ +"""Symbol extraction — locate code symbols (functions/classes/structs) by name. + +KTD1 (Wave 3 plan): Python `ast` (stdlib) for .py files; language-aware regex +for TS/JS/Go/Rust/Java. Avoids tree-sitter native dependency. The +`SymbolExtractor` protocol is the upgrade seam — a future TreeSitterSymbolExtractor +can replace RegexSymbolExtractor behind the same interface. + +ponytail: regex extractor covers ~80% case (top-level function/class/struct +declarations). Ceiling: misses nested signatures inside JSX/TSX generics, +multi-line decorator chains, and macro-generated defs. Upgrade path = tree-sitter. +""" + +from __future__ import annotations + +import ast +import logging +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Protocol, runtime_checkable + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True, slots=True) +class SymbolSpan: + """A located symbol — name, kind, and 1-based inclusive line range.""" + + name: str + kind: str # "function" | "class" | "method" | "struct" | "impl" + start_line: int # 1-based, inclusive + end_line: int # 1-based, inclusive + + +@runtime_checkable +class SymbolExtractor(Protocol): + """Protocol for symbol extractors — runtime_checkable for isinstance/issubclass.""" + + def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]: + """Return all symbols found in `content`. + + `language` is the file extension without leading dot (e.g. "py", "ts"). + Implementations must never raise on extraction failure — return [] on + parse errors and let the caller decide the fallback (full-file read). + """ + ... 
+ + +# --------------------------------------------------------------------------- +# Python — stdlib ast +# --------------------------------------------------------------------------- + + +class AstSymbolExtractor: + """Python symbol extractor using the stdlib `ast` module. + + Captures top-level FunctionDef/AsyncFunctionDef/ClassDef and methods/nested + functions inside classes. The end_line is the last line of the node's + source segment (decorator-inclusive). + """ + + def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]: + if language != "py": + return [] + try: + tree = ast.parse(content) + except SyntaxError as e: + logger.debug("ast.parse failed: %s", e) + return [] + + lines = content.splitlines() + spans: list[SymbolSpan] = [] + for node in ast.walk(tree): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + kind = "method" if _is_method(node) else "function" + spans.append(_span_from_node(node, kind, lines)) + elif isinstance(node, ast.ClassDef): + spans.append(_span_from_node(node, "class", lines)) + return spans + + +def _is_method(node: ast.AST) -> bool: + """A FunctionDef is a method if its parent is a ClassDef. + + `ast.walk` doesn't expose parentage, so we approximate by checking the + node's col_offset == 4 (indented inside a class body). ponytail: this + misses methods in deeply nested classes — ceiling noted; upgrade path = + ast.NodeVisitor with parent tracking. + """ + return getattr(node, "col_offset", 0) > 0 + + +def _span_from_node( + node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef, + kind: str, + lines: list[str], +) -> SymbolSpan: + # ast line numbers are 1-based; start at decorator if present (lineno points + # to the def/class keyword, decorators are above). Use node.lineno for start + # so the returned range matches what the user sees at the def keyword. + start = node.lineno + # node.end_lineno is the last line of the node body (None on old Pythons). 
+ end = node.end_lineno or start + # Clamp to actual file length (defensive — ast should not exceed, but + # malformed files with no trailing newline can confuse end_lineno). + if end > len(lines): + end = len(lines) + return SymbolSpan(name=node.name, kind=kind, start_line=start, end_line=end) + + +# --------------------------------------------------------------------------- +# Regex extractor — TS/JS/Go/Rust/Java +# --------------------------------------------------------------------------- + +# Each pattern matches a declaration and captures the symbol name in group 1. +# Patterns use re.MULTILINE so ^ matches line starts. +_REGEX_PATTERNS: dict[str, list[tuple[str, re.Pattern[str]]]] = { + "ts": [ + ( + "function", + re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE), + ), + ("class", re.compile(r"^\s*(?:export\s+)?(?:abstract\s+)?class\s+(\w+)\b", re.MULTILINE)), + ( + "function", + re.compile( + r"^\s*(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>", + re.MULTILINE, + ), + ), + ], + "js": [ + ("function", re.compile(r"^\s*(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE)), + ("class", re.compile(r"^\s*class\s+(\w+)\b", re.MULTILINE)), + ( + "function", + re.compile( + r"^\s*(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>", re.MULTILINE + ), + ), + ], + "go": [ + ("function", re.compile(r"^func\s+(?:\([^)]*\)\s+)?(\w+)\s*\(", re.MULTILINE)), + ("struct", re.compile(r"^type\s+(\w+)\s+struct\b", re.MULTILINE)), + ], + "rs": [ + ("function", re.compile(r"^\s*(?:pub\s+)?(?:async\s+)?fn\s+(\w+)\s*\(", re.MULTILINE)), + ("struct", re.compile(r"^\s*(?:pub\s+)?struct\s+(\w+)\b", re.MULTILINE)), + ("impl", re.compile(r"^impl\b.*?\s+(\w+)\s*\{", re.MULTILINE)), + ], + "java": [ + ( + "function", + re.compile( + r"^\s*(?:public|private|protected)?\s*(?:static\s+)?(?:final\s+)?(?:\w+(?:<[^>]*>)?)\s+(\w+)\s*\([^)]*\)\s*(?:throws\s+\w+(?:\s*,\s*\w+)*)?\s*\{", + re.MULTILINE, + ), + ), + ( + "class", + 
re.compile(r"^\s*(?:public\s+)?(?:abstract\s+|final\s+)?class\s+(\w+)\b", re.MULTILINE), + ), + ], +} + + +class RegexSymbolExtractor: + """Language-aware regex symbol extractor for TS/JS/Go/Rust/Java. + + Returns SymbolSpans whose end_line is approximated by the next blank line + or next-symbol start (whichever comes first). ponytail: this is an + approximation — true block-end requires language-aware brace matching. + Ceiling: deeply nested blocks may over-extend the range. Upgrade path = + tree-sitter. + """ + + def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]: + patterns = _REGEX_PATTERNS.get(language) + if not patterns: + return [] + + lines = content.splitlines() + # Collect (line_no, name, kind) tuples first, then compute end_line + # as the line before the next symbol starts (or EOF). + raw_hits: list[tuple[int, str, str]] = [] + for kind, pattern in patterns: + for m in pattern.finditer(content): + # Convert match offset to 1-based line number. + line_no = content[: m.start()].count("\n") + 1 + raw_hits.append((line_no, m.group(1), kind)) + + if not raw_hits: + return [] + + # Deduplicate: same (line_no, name) may appear for overlapping patterns. + seen: set[tuple[int, str]] = set() + unique: list[tuple[int, str, str]] = [] + for line_no, name, kind in raw_hits: + key = (line_no, name) + if key in seen: + continue + seen.add(key) + unique.append((line_no, name, kind)) + + unique.sort(key=lambda x: x[0]) + + spans: list[SymbolSpan] = [] + for i, (start_line, name, kind) in enumerate(unique): + if i + 1 < len(unique): + # End at line before next symbol starts, capped at file length. 
+ end_line = unique[i + 1][0] - 1 + else: + end_line = len(lines) + if end_line < start_line: + end_line = start_line + spans.append(SymbolSpan(name=name, kind=kind, start_line=start_line, end_line=end_line)) + return spans + + +# --------------------------------------------------------------------------- +# Dispatch by file extension +# --------------------------------------------------------------------------- + +_EXTENSION_LANGUAGE: dict[str, str] = { + ".py": "py", + ".ts": "ts", + ".tsx": "ts", + ".js": "js", + ".jsx": "js", + ".mjs": "js", + ".cjs": "js", + ".go": "go", + ".rs": "rs", + ".java": "java", +} + +_DEFAULT_EXTRACTOR = AstSymbolExtractor() +_REGEX_EXTRACTOR = RegexSymbolExtractor() + + +def language_for_extension(ext: str) -> str: + """Return the language key for a file extension (with or without leading dot). + + Returns "" for unsupported extensions. + """ + if not ext.startswith("."): + ext = "." + ext + return _EXTENSION_LANGUAGE.get(ext.lower(), "") + + +def get_extractor(language: str) -> SymbolExtractor | None: + """Return the appropriate extractor for `language`, or None if unsupported.""" + if language == "py": + return _DEFAULT_EXTRACTOR + if language in _REGEX_PATTERNS: + return _REGEX_EXTRACTOR + return None + + +def extract_symbols_from_file(path: Path) -> tuple[list[SymbolSpan], str]: + """Read a file and return (symbols, language). + + Returns ([], "") if the extension is unsupported or the file cannot be read. + Never raises — callers use this for fallback routing. 
+ """ + ext = path.suffix.lower() + language = language_for_extension(ext) + if not language: + return [], "" + try: + content = path.read_text(encoding="utf-8", errors="replace") + except OSError as e: + logger.debug("read failed for %s: %s", path, e) + return [], language + extractor = get_extractor(language) + if extractor is None: + return [], language + return extractor.extract_symbols(content, language), language diff --git a/tests/unit/test_read_file_tool.py b/tests/unit/test_read_file_tool.py new file mode 100644 index 0000000..87363e4 --- /dev/null +++ b/tests/unit/test_read_file_tool.py @@ -0,0 +1,367 @@ +"""Unit tests for ReadFileTool — G5 (R22, R23) + characterization baseline. + +Per plan U1 Execution note: characterization-first — assert that +`symbol=None` returns the full file content (matches pre-existing benchmark +`_FakeTool` shape) before adding symbol-extraction behavior. +""" + +from __future__ import annotations + +import textwrap + +import pytest + +from agentkit.tools.file_read import ReadFileTool + + +# --------------------------------------------------------------------------- +# Schema +# --------------------------------------------------------------------------- + + +class TestReadFileToolSchema: + def test_name_is_read_file(self): + tool = ReadFileTool() + assert tool.name == "read_file" + + def test_required_path(self): + tool = ReadFileTool() + assert "path" in tool.input_schema["required"] + assert "path" in tool.input_schema["properties"] + + def test_optional_symbol_and_lines(self): + tool = ReadFileTool() + props = tool.input_schema["properties"] + assert "symbol" in props + assert "start_line" in props + assert "end_line" in props + # None of the optional fields should be in `required`. + required = set(tool.input_schema["required"]) + assert required == {"path"} + + def test_additional_properties_false(self): + # LLM tool-call schemas should reject unknown args (Wave 1 U3 pattern). 
+ tool = ReadFileTool() + assert tool.input_schema.get("additionalProperties") is False + + def test_tags_contain_io_and_read(self): + tool = ReadFileTool() + assert "io" in tool.tags + assert "read" in tool.tags + + +# --------------------------------------------------------------------------- +# Characterization — symbol=None returns full file +# --------------------------------------------------------------------------- + + +@pytest.fixture +def sample_py_file(tmp_path): + path = tmp_path / "sample.py" + path.write_text( + textwrap.dedent(''' + """Sample module.""" + + def my_func(): + return 42 + + + class MyClass: + attr = 1 + + def method_a(self): + return self.attr + ''').lstrip(), + encoding="utf-8", + ) + return path + + +@pytest.fixture +def sample_ts_file(tmp_path): + path = tmp_path / "sample.ts" + path.write_text( + textwrap.dedent(''' + export function renderComponent(): JSX.Element { + return
; + } + + export class BaseService { + abstract run(): void; + } + ''').lstrip(), + encoding="utf-8", + ) + return path + + +class TestCharacterizationFullFile: + """symbol=None returns the whole file (matches _FakeTool baseline).""" + + @pytest.mark.asyncio + async def test_full_file_returned_when_symbol_none(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_py_file)) + + assert result["is_error"] is False + assert result["path"] == str(sample_py_file) + assert result["start_line"] == 1 + assert result["end_line"] == result["total_lines"] + assert "def my_func" in result["content"] + assert "class MyClass" in result["content"] + assert result["content"].startswith('"""Sample module."""') + + @pytest.mark.asyncio + async def test_full_file_includes_all_lines(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_py_file)) + assert result["total_lines"] >= 8 + assert result["content"].count("\n") >= result["total_lines"] - 1 + + +# --------------------------------------------------------------------------- +# Symbol slicing — happy paths +# --------------------------------------------------------------------------- + + +class TestSymbolSlicing: + @pytest.mark.asyncio + async def test_python_function(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_py_file), symbol="my_func") + + assert result["is_error"] is False + assert result["symbol"] == "my_func" + assert result["symbol_kind"] == "function" + assert "def my_func" in result["content"] + assert "return 42" in result["content"] + # Should NOT include the class below. 
+ assert "class MyClass" not in result["content"] + + @pytest.mark.asyncio + async def test_python_class_includes_method(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_py_file), symbol="MyClass") + + assert result["is_error"] is False + assert result["symbol"] == "MyClass" + assert result["symbol_kind"] == "class" + assert "class MyClass" in result["content"] + assert "def method_a" in result["content"] # method included + + @pytest.mark.asyncio + async def test_python_method_directly(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_py_file), symbol="method_a") + + assert result["is_error"] is False + assert result["symbol"] == "method_a" + assert result["symbol_kind"] == "method" + assert "def method_a" in result["content"] + + @pytest.mark.asyncio + async def test_typescript_function(self, sample_ts_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_ts_file), symbol="renderComponent") + + assert result["is_error"] is False + assert result["symbol"] == "renderComponent" + assert "renderComponent" in result["content"] + # Should not include the class below. 
+ assert "BaseService" not in result["content"] + + @pytest.mark.asyncio + async def test_typescript_class(self, sample_ts_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_ts_file), symbol="BaseService") + + assert result["is_error"] is False + assert result["symbol"] == "BaseService" + assert result["symbol_kind"] == "class" + assert "BaseService" in result["content"] + + +# --------------------------------------------------------------------------- +# Symbol slicing — edge cases +# --------------------------------------------------------------------------- + + +class TestSymbolSlicingEdgeCases: + @pytest.mark.asyncio + async def test_symbol_not_found_lists_available(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_py_file), symbol="nonexistent") + + assert result["is_error"] is False # soft error, not hard + assert result["content"] == "" + assert result["symbol"] == "nonexistent" + available = result["available_symbols"] + assert "my_func" in available + assert "MyClass" in available + assert "method_a" in available + assert "nonexistent" not in result["content"] + + @pytest.mark.asyncio + async def test_unsupported_extension_returns_full_with_note(self, tmp_path): + path = tmp_path / "notes.md" + path.write_text("# Hello\nworld\n", encoding="utf-8") + tool = ReadFileTool() + result = await tool.execute(path=str(path), symbol="anything") + + assert result["is_error"] is False + assert result["content"] == "# Hello\nworld\n" + assert "symbol extraction not supported" in result["note"] + assert ".md" in result["note"] + + @pytest.mark.asyncio + async def test_empty_file(self, tmp_path): + path = tmp_path / "empty.py" + path.write_text("", encoding="utf-8") + tool = ReadFileTool() + result = await tool.execute(path=str(path)) + + assert result["is_error"] is False + assert result["content"] == "" + assert result["total_lines"] == 0 + + @pytest.mark.asyncio + async def 
test_file_with_no_symbols(self, tmp_path): + path = tmp_path / "data.py" + path.write_text("# just a comment\nPI = 3.14\n", encoding="utf-8") + tool = ReadFileTool() + result = await tool.execute(path=str(path), symbol="PI") + + # PI is not a def/class — extractor finds no symbols; soft error lists available. + assert result["is_error"] is False + assert result["content"] == "" + assert result["available_symbols"] == [] + + +# --------------------------------------------------------------------------- +# Error paths +# --------------------------------------------------------------------------- + + +class TestReadFileToolErrors: + @pytest.mark.asyncio + async def test_path_required(self): + tool = ReadFileTool() + result = await tool.execute() + assert result["is_error"] is True + assert "path" in result["error"].lower() + + @pytest.mark.asyncio + async def test_path_empty_string(self): + tool = ReadFileTool() + result = await tool.execute(path="") + assert result["is_error"] is True + + @pytest.mark.asyncio + async def test_file_not_found(self, tmp_path): + tool = ReadFileTool() + result = await tool.execute(path=str(tmp_path / "missing.py")) + assert result["is_error"] is True + assert "not found" in result["error"].lower() + + @pytest.mark.asyncio + async def test_path_is_directory(self, tmp_path): + tool = ReadFileTool() + result = await tool.execute(path=str(tmp_path)) + assert result["is_error"] is True + assert "directory" in result["error"].lower() + + +# --------------------------------------------------------------------------- +# Manual line slicing +# --------------------------------------------------------------------------- + + +class TestManualLineSlicing: + @pytest.mark.asyncio + async def test_start_and_end_line(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute( + path=str(sample_py_file), + start_line=3, + end_line=5, + ) + assert result["is_error"] is False + assert result["start_line"] == 3 + assert result["end_line"] 
== 5 + # Lines 3-5 of the sample file: + # line 3: "def my_func():" + # line 4: " return 42" + # line 5: "" (blank) + assert "def my_func" in result["content"] + assert "return 42" in result["content"] + + @pytest.mark.asyncio + async def test_start_line_only_extends_to_eof(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_py_file), start_line=8) + assert result["is_error"] is False + assert result["start_line"] == 8 + assert result["end_line"] == result["total_lines"] + + @pytest.mark.asyncio + async def test_end_line_only_starts_at_one(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_py_file), end_line=2) + assert result["is_error"] is False + assert result["start_line"] == 1 + assert result["end_line"] == 2 + + @pytest.mark.asyncio + async def test_invalid_start_line_zero(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute(path=str(sample_py_file), start_line=0) + assert result["is_error"] is True + assert "start_line" in result["error"].lower() + + @pytest.mark.asyncio + async def test_end_before_start(self, sample_py_file): + tool = ReadFileTool() + result = await tool.execute( + path=str(sample_py_file), + start_line=5, + end_line=3, + ) + assert result["is_error"] is True + assert "end_line" in result["error"].lower() + + @pytest.mark.asyncio + async def test_manual_lines_override_symbol(self, sample_py_file): + # Per plan U1 Approach: "start_line/end_line overrides symbol". + tool = ReadFileTool() + result = await tool.execute( + path=str(sample_py_file), + symbol="my_func", + start_line=1, + end_line=1, + ) + assert result["is_error"] is False + # Manual slicing won — symbol field absent. 
+ assert "symbol" not in result or result.get("symbol") is None + assert result["start_line"] == 1 + assert result["end_line"] == 1 + + +# --------------------------------------------------------------------------- +# Integration — tool registry discovery +# --------------------------------------------------------------------------- + + +class TestToolRegistryDiscovery: + def test_instantiable_without_args(self): + # Default constructor — matches the convention used by ToolRegistry + # to instantiate tools by class. + tool = ReadFileTool() + assert tool.name == "read_file" + + def test_to_dict_serializable(self): + tool = ReadFileTool() + d = tool.to_dict() + assert d["name"] == "read_file" + assert "input_schema" in d + assert "output_schema" in d + assert d["tags"] == ["io", "file", "read"] diff --git a/tests/unit/test_symbol_extractor.py b/tests/unit/test_symbol_extractor.py new file mode 100644 index 0000000..d7b444f --- /dev/null +++ b/tests/unit/test_symbol_extractor.py @@ -0,0 +1,359 @@ +"""Unit tests for SymbolExtractor — AstSymbolExtractor + RegexSymbolExtractor. + +Covers R22 (file reading supports symbol/function granularity) and KTD1 +(Python ast + language-aware regex, no tree-sitter dependency). 
+""" + +from __future__ import annotations + +import textwrap + +import pytest + +from agentkit.tools.symbol_extractor import ( + AstSymbolExtractor, + RegexSymbolExtractor, + SymbolSpan, + extract_symbols_from_file, + get_extractor, + language_for_extension, +) + + +# --------------------------------------------------------------------------- +# language_for_extension +# --------------------------------------------------------------------------- + + +class TestLanguageForExtension: + def test_python_extensions(self): + assert language_for_extension("py") == "py" + assert language_for_extension(".py") == "py" + assert language_for_extension(".PY") == "py" # case-insensitive + + def test_typescript_javascript(self): + assert language_for_extension(".ts") == "ts" + assert language_for_extension(".tsx") == "ts" + assert language_for_extension(".js") == "js" + assert language_for_extension(".jsx") == "js" + assert language_for_extension(".mjs") == "js" + assert language_for_extension(".cjs") == "js" + + def test_go_rust_java(self): + assert language_for_extension(".go") == "go" + assert language_for_extension(".rs") == "rs" + assert language_for_extension(".java") == "java" + + def test_unsupported_returns_empty(self): + assert language_for_extension(".md") == "" + assert language_for_extension(".txt") == "" + assert language_for_extension("") == "" + assert language_for_extension(".unknown") == "" + + +# --------------------------------------------------------------------------- +# AstSymbolExtractor — Python +# --------------------------------------------------------------------------- + + +class TestAstSymbolExtractor: + extractor = AstSymbolExtractor() + + def test_unsupported_language_returns_empty(self): + assert self.extractor.extract_symbols("function foo() {}", "ts") == [] + + def test_syntax_error_returns_empty(self): + # Never raises — callers rely on this for fallback routing. 
def test_top_level_function(self):
    """A lone module-level `def` yields one function span covering lines 1-2."""
    found = self.extractor.extract_symbols("def my_func():\n return 42\n", "py")
    assert len(found) == 1
    only = found[0]
    # Compare all four fields at once; each is checked against the same
    # expected value the individual assertions would use.
    observed = (only.name, only.kind, only.start_line, only.end_line)
    assert observed == ("my_func", "function", 1, 2)
def test_unsupported_language_returns_empty(self):
    """Language keys outside the regex extractor's set produce an empty list."""
    rejected = (
        ("def foo(): pass", "py"),
        ("function foo() {}", "rb"),
    )
    for snippet, lang in rejected:
        assert self.extractor.extract_symbols(snippet, lang) == []
def test_javascript_arrow_const(self):
    """A single-line `const` arrow function is reported under its binding name."""
    names = [sym.name for sym in self.extractor.extract_symbols("const bar = () => 42;\n", "js")]
    assert "bar" in names
def test_rust_impl(self):
    """An `impl Config` block is reported as a symbol named `Config` of kind `impl`."""
    source = "impl Config {\n pub fn new() -> Self { Self { path: String::new() } }\n}\n"
    pairs = [(sym.name, sym.kind) for sym in self.extractor.extract_symbols(source, "rs")]
    assert ("Config", "impl") in pairs
class TestGetExtractor:
    """Dispatch from a language key to an extractor implementation."""

    def test_python_returns_ast_extractor(self):
        """The `py` key is served by an `AstSymbolExtractor`."""
        chosen = get_extractor("py")
        assert chosen is not None
        assert isinstance(chosen, AstSymbolExtractor)

    def test_typescript_returns_regex_extractor(self):
        """The `ts` key is served by a `RegexSymbolExtractor`."""
        chosen = get_extractor("ts")
        assert chosen is not None
        assert isinstance(chosen, RegexSymbolExtractor)

    def test_unsupported_returns_none(self):
        """Keys with no extractor (including the empty key) return `None`."""
        for key in ("md", "", "unknown"):
            assert get_extractor(key) is None
class TestSymbolSpan:
    """Value-object guarantees of `SymbolSpan`: immutability, equality, hashing."""

    def test_frozen_dataclass(self):
        """Assigning to a field of an existing span must be rejected."""
        span = SymbolSpan(name="foo", kind="function", start_line=1, end_line=3)
        assert span.name == "foo"
        # A frozen dataclass rejects assignment with
        # `dataclasses.FrozenInstanceError`, a subclass of AttributeError.
        # Expecting AttributeError instead of bare Exception stops an
        # unrelated error raised by the assignment from satisfying the test.
        # NOTE(review): assumes SymbolSpan is a frozen dataclass, as this
        # test's name states — confirm against symbol_extractor.
        with pytest.raises(AttributeError):
            # The explanatory text sits in its own `#` comment so the
            # type-ignore tag stays well-formed for type checkers.
            span.name = "bar"  # type: ignore[misc]  # frozen
        # The rejected write must leave the original value untouched.
        assert span.name == "foo"

    def test_equality(self):
        """Spans built from identical field values compare and hash equal."""
        a = SymbolSpan("foo", "function", 1, 3)
        b = SymbolSpan("foo", "function", 1, 3)
        assert a == b
        assert hash(a) == hash(b)