feat(U1): G5 SymbolExtractor + ReadFileTool with symbol slicing

- SymbolExtractor protocol + SymbolSpan dataclass
- AstSymbolExtractor for Python (stdlib ast, no tree-sitter dep — KTD1)
- RegexSymbolExtractor for TS/JS/Go/Rust/Java (language-aware regex)
- ReadFileTool with path/symbol/start_line/end_line params
- symbol=None returns full file (characterization baseline matching _FakeTool)
- symbol='foo' returns first matching symbol's line range
- symbol not found returns available_symbols list (soft error)
- Unsupported extension returns full file with note
- Manual start_line/end_line overrides symbol
- 66 unit tests covering R22/R23 + characterization + edge cases
This commit is contained in:
chiguyong 2026-06-29 23:54:44 +08:00
parent 23f7448d55
commit 50885fbc62
5 changed files with 1268 additions and 0 deletions

View File

@ -18,6 +18,7 @@ from agentkit.tools.memory_tool import MemoryTool
from agentkit.tools.web_search import WebSearchTool
from agentkit.tools.builtin import RunTestsTool, ToolSearchTool
from agentkit.tools.search import ToolSearchIndex
from agentkit.tools.file_read import ReadFileTool
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
try:
@ -52,4 +53,5 @@ __all__ = [
"OutputParser",
"ParsedOutput",
"ErrorType",
"ReadFileTool",
]

View File

@ -0,0 +1,262 @@
"""ReadFileTool — file reading with optional symbol-level sharding (G5, R22/R23).
Backward compatible with the pre-existing `_FakeTool` benchmark shape when
`symbol=None`, returns the full file content. When `symbol="foo"`, returns
the line range of the first matching symbol via `SymbolExtractor`.
KTD2 (Wave 3 plan): dedicated tool, does NOT extend ShellTool keeps the
file-reading contract clean and gives the LLM a focused schema.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
from agentkit.tools.base import Tool
from agentkit.tools.symbol_extractor import (
SymbolSpan,
extract_symbols_from_file,
language_for_extension,
)
logger = logging.getLogger(__name__)
class ReadFileTool(Tool):
"""Read a file from the filesystem, optionally sliced to a single symbol.
Tool name `read_file` matches the reserved entry in
`core/react.py:_DEFAULT_CORE_TOOLS` (which previously had no real
implementation only `_FakeTool` stubs in `cli/benchmark.py`).
Backward-compat contract: `symbol=None` returns the full file content,
matching the shape `{"path": ...}` that downstream callers (benchmark,
phase whitelist) already expect.
"""
def __init__(
self,
name: str = "read_file",
description: str | None = None,
input_schema: dict[str, Any] | None = None,
output_schema: dict[str, Any] | None = None,
version: str = "1.0.0",
tags: list[str] | None = None,
):
super().__init__(
name=name,
description=description
or (
"Read a file from the filesystem. By default returns the full file "
"content. Pass `symbol` (function/class/struct name) to slice to just "
"that symbol's line range — saves context when you only need one "
"function from a large file. Pass `start_line`/`end_line` for manual "
"slicing. If `symbol` is set but not found, returns the available "
"symbol names so you can retry."
),
input_schema=input_schema or self._default_input_schema(),
output_schema=output_schema or self._default_output_schema(),
version=version,
tags=tags or ["io", "file", "read"],
)
@staticmethod
def _default_input_schema() -> dict[str, Any]:
return {
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to the file to read (absolute or relative to cwd).",
},
"symbol": {
"type": "string",
"description": (
"Optional: name of a function/class/struct/method to slice to. "
"When set, returns only the line range of the first matching "
"symbol. Supported languages: py, ts/tsx, js/jsx, go, rs, java."
),
},
"start_line": {
"type": "integer",
"description": "Optional 1-based start line for manual slicing. Overrides `symbol`.",
"minimum": 1,
},
"end_line": {
"type": "integer",
"description": "Optional 1-based end line (inclusive) for manual slicing. Overrides `symbol`.",
"minimum": 1,
},
},
"required": ["path"],
"additionalProperties": False,
}
@staticmethod
def _default_output_schema() -> dict[str, Any]:
return {
"type": "object",
"properties": {
"content": {"type": "string"},
"path": {"type": "string"},
"start_line": {"type": "integer"},
"end_line": {"type": "integer"},
"symbol": {"type": "string"},
"available_symbols": {
"type": "array",
"items": {"type": "string"},
"description": "Populated when `symbol` is set but not found.",
},
"note": {"type": "string"},
"is_error": {"type": "boolean"},
"error": {"type": "string"},
},
}
async def execute(self, **kwargs) -> dict[str, Any]:
raw_path = kwargs.get("path")
if not raw_path:
return self._error("`path` is required")
path = Path(raw_path)
if not path.is_absolute():
path = path.resolve()
symbol = kwargs.get("symbol")
start_line = kwargs.get("start_line")
end_line = kwargs.get("end_line")
# Validate/sanitize line overrides.
if start_line is not None and (not isinstance(start_line, int) or start_line < 1):
return self._error(f"`start_line` must be a positive integer, got {start_line!r}")
if end_line is not None and (not isinstance(end_line, int) or end_line < 1):
return self._error(f"`end_line` must be a positive integer, got {end_line!r}")
if start_line is not None and end_line is not None and end_line < start_line:
return self._error(f"`end_line` ({end_line}) must be >= `start_line` ({start_line})")
# Filesystem checks.
if not path.exists():
return self._error(f"File not found: {path}", path=str(path))
if path.is_dir():
return self._error(f"Path is a directory, not a file: {path}", path=str(path))
try:
content = path.read_text(encoding="utf-8", errors="replace")
except PermissionError as e:
return self._error(f"Permission denied: {path}", path=str(path), detail=str(e))
except OSError as e:
return self._error(f"Failed to read {path}: {e}", path=str(path))
lines = content.splitlines()
total_lines = len(lines)
# Manual slicing takes precedence over symbol (per plan U1 Approach).
if start_line is not None or end_line is not None:
s = max(1, start_line or 1)
e = min(total_lines, end_line or total_lines)
sliced = "\n".join(lines[s - 1 : e])
return {
"content": sliced,
"path": str(path),
"start_line": s,
"end_line": e,
"total_lines": total_lines,
"is_error": False,
}
# Symbol slicing.
if symbol:
ext = path.suffix.lower()
language = language_for_extension(ext)
if not language:
# Unsupported extension: return full file with note (per plan U1 Edge case).
return {
"content": content,
"path": str(path),
"start_line": 1,
"end_line": total_lines,
"total_lines": total_lines,
"note": f"symbol extraction not supported for {ext or 'unknown extension'}",
"is_error": False,
}
spans, _lang = extract_symbols_from_file(path)
# Re-extract using the content we already read so we don't read the file twice.
if not spans:
# Try extraction from in-memory content (path-based extraction may
# have failed silently on OSError; we already read it successfully).
from agentkit.tools.symbol_extractor import get_extractor
extractor = get_extractor(language)
if extractor is not None:
spans = extractor.extract_symbols(content, language)
match = _find_symbol(spans, symbol)
if match is None:
available = sorted({s.name for s in spans})
return {
"content": "",
"path": str(path),
"symbol": symbol,
"available_symbols": available,
"is_error": False,
"note": (
f"Symbol {symbol!r} not found in {path.name}. "
f"Available: {', '.join(available) if available else '(none)'}"
),
}
s = match.start_line
e = min(match.end_line, total_lines)
sliced = "\n".join(lines[s - 1 : e])
return {
"content": sliced,
"path": str(path),
"symbol": symbol,
"symbol_kind": match.kind,
"start_line": s,
"end_line": e,
"total_lines": total_lines,
"is_error": False,
}
# Default: full file (characterization baseline — matches _FakeTool shape).
return {
"content": content,
"path": str(path),
"start_line": 1,
"end_line": total_lines,
"total_lines": total_lines,
"is_error": False,
}
@staticmethod
def _error(
message: str, *, path: str | None = None, detail: str | None = None
) -> dict[str, Any]:
result: dict[str, Any] = {
"content": "",
"is_error": True,
"error": message,
}
if path is not None:
result["path"] = path
if detail is not None:
result["detail"] = detail
return result
def _find_symbol(spans: list[SymbolSpan], name: str) -> SymbolSpan | None:
"""Find the first symbol matching `name`. Case-sensitive.
ponytail: linear scan is fine for typical file symbol counts (<100). The
extractor already returns symbols sorted by start_line; first match wins
for ambiguous overloads (e.g., Python classes with same name in different
modules not relevant within one file).
"""
for span in spans:
if span.name == name:
return span
return None

View File

@ -0,0 +1,278 @@
"""Symbol extraction — locate code symbols (functions/classes/structs) by name.
KTD1 (Wave 3 plan): Python `ast` (stdlib) for .py files; language-aware regex
for TS/JS/Go/Rust/Java. Avoids tree-sitter native dependency. The
`SymbolExtractor` protocol is the upgrade seam a future TreeSitterSymbolExtractor
can replace RegexSymbolExtractor behind the same interface.
ponytail: regex extractor covers ~80% case (top-level function/class/struct
declarations). Ceiling: misses nested signatures inside JSX/TSX generics,
multi-line decorator chains, and macro-generated defs. Upgrade path = tree-sitter.
"""
from __future__ import annotations
import ast
import logging
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Protocol, runtime_checkable
logger = logging.getLogger(__name__)
@dataclass(frozen=True, slots=True)
class SymbolSpan:
"""A located symbol — name, kind, and 1-based inclusive line range."""
name: str
kind: str # "function" | "class" | "method" | "struct" | "impl"
start_line: int # 1-based, inclusive
end_line: int # 1-based, inclusive
@runtime_checkable
class SymbolExtractor(Protocol):
"""Protocol for symbol extractors — runtime_checkable for isinstance/issubclass."""
def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
"""Return all symbols found in `content`.
`language` is the file extension without leading dot (e.g. "py", "ts").
Implementations must never raise on extraction failure return [] on
parse errors and let the caller decide the fallback (full-file read).
"""
...
# ---------------------------------------------------------------------------
# Python — stdlib ast
# ---------------------------------------------------------------------------
class AstSymbolExtractor:
"""Python symbol extractor using the stdlib `ast` module.
Captures top-level FunctionDef/AsyncFunctionDef/ClassDef and methods/nested
functions inside classes. The end_line is the last line of the node's
source segment (decorator-inclusive).
"""
def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
if language != "py":
return []
try:
tree = ast.parse(content)
except SyntaxError as e:
logger.debug("ast.parse failed: %s", e)
return []
lines = content.splitlines()
spans: list[SymbolSpan] = []
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
kind = "method" if _is_method(node) else "function"
spans.append(_span_from_node(node, kind, lines))
elif isinstance(node, ast.ClassDef):
spans.append(_span_from_node(node, "class", lines))
return spans
def _is_method(node: ast.AST) -> bool:
"""A FunctionDef is a method if its parent is a ClassDef.
`ast.walk` doesn't expose parentage, so we approximate by checking the
node's col_offset == 4 (indented inside a class body). ponytail: this
misses methods in deeply nested classes ceiling noted; upgrade path =
ast.NodeVisitor with parent tracking.
"""
return getattr(node, "col_offset", 0) > 0
def _span_from_node(
node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef,
kind: str,
lines: list[str],
) -> SymbolSpan:
# ast line numbers are 1-based; start at decorator if present (lineno points
# to the def/class keyword, decorators are above). Use node.lineno for start
# so the returned range matches what the user sees at the def keyword.
start = node.lineno
# node.end_lineno is the last line of the node body (None on old Pythons).
end = node.end_lineno or start
# Clamp to actual file length (defensive — ast should not exceed, but
# malformed files with no trailing newline can confuse end_lineno).
if end > len(lines):
end = len(lines)
return SymbolSpan(name=node.name, kind=kind, start_line=start, end_line=end)
# ---------------------------------------------------------------------------
# Regex extractor — TS/JS/Go/Rust/Java
# ---------------------------------------------------------------------------
# Each pattern matches a declaration and captures the symbol name in group 1.
# Patterns use re.MULTILINE so ^ matches line starts.
_REGEX_PATTERNS: dict[str, list[tuple[str, re.Pattern[str]]]] = {
"ts": [
(
"function",
re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE),
),
("class", re.compile(r"^\s*(?:export\s+)?(?:abstract\s+)?class\s+(\w+)\b", re.MULTILINE)),
(
"function",
re.compile(
r"^\s*(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>",
re.MULTILINE,
),
),
],
"js": [
("function", re.compile(r"^\s*(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE)),
("class", re.compile(r"^\s*class\s+(\w+)\b", re.MULTILINE)),
(
"function",
re.compile(
r"^\s*(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>", re.MULTILINE
),
),
],
"go": [
("function", re.compile(r"^func\s+(?:\([^)]*\)\s+)?(\w+)\s*\(", re.MULTILINE)),
("struct", re.compile(r"^type\s+(\w+)\s+struct\b", re.MULTILINE)),
],
"rs": [
("function", re.compile(r"^\s*(?:pub\s+)?(?:async\s+)?fn\s+(\w+)\s*\(", re.MULTILINE)),
("struct", re.compile(r"^\s*(?:pub\s+)?struct\s+(\w+)\b", re.MULTILINE)),
("impl", re.compile(r"^impl\b.*?\s+(\w+)\s*\{", re.MULTILINE)),
],
"java": [
(
"function",
re.compile(
r"^\s*(?:public|private|protected)?\s*(?:static\s+)?(?:final\s+)?(?:\w+(?:<[^>]*>)?)\s+(\w+)\s*\([^)]*\)\s*(?:throws\s+\w+(?:\s*,\s*\w+)*)?\s*\{",
re.MULTILINE,
),
),
(
"class",
re.compile(r"^\s*(?:public\s+)?(?:abstract\s+|final\s+)?class\s+(\w+)\b", re.MULTILINE),
),
],
}
class RegexSymbolExtractor:
"""Language-aware regex symbol extractor for TS/JS/Go/Rust/Java.
Returns SymbolSpans whose end_line is approximated by the next blank line
or next-symbol start (whichever comes first). ponytail: this is an
approximation true block-end requires language-aware brace matching.
Ceiling: deeply nested blocks may over-extend the range. Upgrade path =
tree-sitter.
"""
def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
patterns = _REGEX_PATTERNS.get(language)
if not patterns:
return []
lines = content.splitlines()
# Collect (line_no, name, kind) tuples first, then compute end_line
# as the line before the next symbol starts (or EOF).
raw_hits: list[tuple[int, str, str]] = []
for kind, pattern in patterns:
for m in pattern.finditer(content):
# Convert match offset to 1-based line number.
line_no = content[: m.start()].count("\n") + 1
raw_hits.append((line_no, m.group(1), kind))
if not raw_hits:
return []
# Deduplicate: same (line_no, name) may appear for overlapping patterns.
seen: set[tuple[int, str]] = set()
unique: list[tuple[int, str, str]] = []
for line_no, name, kind in raw_hits:
key = (line_no, name)
if key in seen:
continue
seen.add(key)
unique.append((line_no, name, kind))
unique.sort(key=lambda x: x[0])
spans: list[SymbolSpan] = []
for i, (start_line, name, kind) in enumerate(unique):
if i + 1 < len(unique):
# End at line before next symbol starts, capped at file length.
end_line = unique[i + 1][0] - 1
else:
end_line = len(lines)
if end_line < start_line:
end_line = start_line
spans.append(SymbolSpan(name=name, kind=kind, start_line=start_line, end_line=end_line))
return spans
# ---------------------------------------------------------------------------
# Dispatch by file extension
# ---------------------------------------------------------------------------
_EXTENSION_LANGUAGE: dict[str, str] = {
".py": "py",
".ts": "ts",
".tsx": "ts",
".js": "js",
".jsx": "js",
".mjs": "js",
".cjs": "js",
".go": "go",
".rs": "rs",
".java": "java",
}
_DEFAULT_EXTRACTOR = AstSymbolExtractor()
_REGEX_EXTRACTOR = RegexSymbolExtractor()
def language_for_extension(ext: str) -> str:
"""Return the language key for a file extension (with or without leading dot).
Returns "" for unsupported extensions.
"""
if not ext.startswith("."):
ext = "." + ext
return _EXTENSION_LANGUAGE.get(ext.lower(), "")
def get_extractor(language: str) -> SymbolExtractor | None:
"""Return the appropriate extractor for `language`, or None if unsupported."""
if language == "py":
return _DEFAULT_EXTRACTOR
if language in _REGEX_PATTERNS:
return _REGEX_EXTRACTOR
return None
def extract_symbols_from_file(path: Path) -> tuple[list[SymbolSpan], str]:
"""Read a file and return (symbols, language).
Returns ([], "") if the extension is unsupported or the file cannot be read.
Never raises callers use this for fallback routing.
"""
ext = path.suffix.lower()
language = language_for_extension(ext)
if not language:
return [], ""
try:
content = path.read_text(encoding="utf-8", errors="replace")
except OSError as e:
logger.debug("read failed for %s: %s", path, e)
return [], language
extractor = get_extractor(language)
if extractor is None:
return [], language
return extractor.extract_symbols(content, language), language

View File

@ -0,0 +1,367 @@
"""Unit tests for ReadFileTool — G5 (R22, R23) + characterization baseline.
Per plan U1 Execution note: characterization-first assert that
`symbol=None` returns the full file content (matches pre-existing benchmark
`_FakeTool` shape) before adding symbol-extraction behavior.
"""
from __future__ import annotations
import textwrap
import pytest
from agentkit.tools.file_read import ReadFileTool
# ---------------------------------------------------------------------------
# Schema
# ---------------------------------------------------------------------------
class TestReadFileToolSchema:
def test_name_is_read_file(self):
tool = ReadFileTool()
assert tool.name == "read_file"
def test_required_path(self):
tool = ReadFileTool()
assert "path" in tool.input_schema["required"]
assert "path" in tool.input_schema["properties"]
def test_optional_symbol_and_lines(self):
tool = ReadFileTool()
props = tool.input_schema["properties"]
assert "symbol" in props
assert "start_line" in props
assert "end_line" in props
# None of the optional fields should be in `required`.
required = set(tool.input_schema["required"])
assert required == {"path"}
def test_additional_properties_false(self):
# LLM tool-call schemas should reject unknown args (Wave 1 U3 pattern).
tool = ReadFileTool()
assert tool.input_schema.get("additionalProperties") is False
def test_tags_contain_io_and_read(self):
tool = ReadFileTool()
assert "io" in tool.tags
assert "read" in tool.tags
# ---------------------------------------------------------------------------
# Characterization — symbol=None returns full file
# ---------------------------------------------------------------------------
@pytest.fixture
def sample_py_file(tmp_path):
path = tmp_path / "sample.py"
path.write_text(
textwrap.dedent('''
"""Sample module."""
def my_func():
return 42
class MyClass:
attr = 1
def method_a(self):
return self.attr
''').lstrip(),
encoding="utf-8",
)
return path
@pytest.fixture
def sample_ts_file(tmp_path):
path = tmp_path / "sample.ts"
path.write_text(
textwrap.dedent('''
export function renderComponent(): JSX.Element {
return <div/>;
}
export class BaseService {
abstract run(): void;
}
''').lstrip(),
encoding="utf-8",
)
return path
class TestCharacterizationFullFile:
"""symbol=None returns the whole file (matches _FakeTool baseline)."""
@pytest.mark.asyncio
async def test_full_file_returned_when_symbol_none(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file))
assert result["is_error"] is False
assert result["path"] == str(sample_py_file)
assert result["start_line"] == 1
assert result["end_line"] == result["total_lines"]
assert "def my_func" in result["content"]
assert "class MyClass" in result["content"]
assert result["content"].startswith('"""Sample module."""')
@pytest.mark.asyncio
async def test_full_file_includes_all_lines(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file))
assert result["total_lines"] >= 8
assert result["content"].count("\n") >= result["total_lines"] - 1
# ---------------------------------------------------------------------------
# Symbol slicing — happy paths
# ---------------------------------------------------------------------------
class TestSymbolSlicing:
@pytest.mark.asyncio
async def test_python_function(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), symbol="my_func")
assert result["is_error"] is False
assert result["symbol"] == "my_func"
assert result["symbol_kind"] == "function"
assert "def my_func" in result["content"]
assert "return 42" in result["content"]
# Should NOT include the class below.
assert "class MyClass" not in result["content"]
@pytest.mark.asyncio
async def test_python_class_includes_method(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), symbol="MyClass")
assert result["is_error"] is False
assert result["symbol"] == "MyClass"
assert result["symbol_kind"] == "class"
assert "class MyClass" in result["content"]
assert "def method_a" in result["content"] # method included
@pytest.mark.asyncio
async def test_python_method_directly(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), symbol="method_a")
assert result["is_error"] is False
assert result["symbol"] == "method_a"
assert result["symbol_kind"] == "method"
assert "def method_a" in result["content"]
@pytest.mark.asyncio
async def test_typescript_function(self, sample_ts_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_ts_file), symbol="renderComponent")
assert result["is_error"] is False
assert result["symbol"] == "renderComponent"
assert "renderComponent" in result["content"]
# Should not include the class below.
assert "BaseService" not in result["content"]
@pytest.mark.asyncio
async def test_typescript_class(self, sample_ts_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_ts_file), symbol="BaseService")
assert result["is_error"] is False
assert result["symbol"] == "BaseService"
assert result["symbol_kind"] == "class"
assert "BaseService" in result["content"]
# ---------------------------------------------------------------------------
# Symbol slicing — edge cases
# ---------------------------------------------------------------------------
class TestSymbolSlicingEdgeCases:
@pytest.mark.asyncio
async def test_symbol_not_found_lists_available(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), symbol="nonexistent")
assert result["is_error"] is False # soft error, not hard
assert result["content"] == ""
assert result["symbol"] == "nonexistent"
available = result["available_symbols"]
assert "my_func" in available
assert "MyClass" in available
assert "method_a" in available
assert "nonexistent" not in result["content"]
@pytest.mark.asyncio
async def test_unsupported_extension_returns_full_with_note(self, tmp_path):
path = tmp_path / "notes.md"
path.write_text("# Hello\nworld\n", encoding="utf-8")
tool = ReadFileTool()
result = await tool.execute(path=str(path), symbol="anything")
assert result["is_error"] is False
assert result["content"] == "# Hello\nworld\n"
assert "symbol extraction not supported" in result["note"]
assert ".md" in result["note"]
@pytest.mark.asyncio
async def test_empty_file(self, tmp_path):
path = tmp_path / "empty.py"
path.write_text("", encoding="utf-8")
tool = ReadFileTool()
result = await tool.execute(path=str(path))
assert result["is_error"] is False
assert result["content"] == ""
assert result["total_lines"] == 0
@pytest.mark.asyncio
async def test_file_with_no_symbols(self, tmp_path):
path = tmp_path / "data.py"
path.write_text("# just a comment\nPI = 3.14\n", encoding="utf-8")
tool = ReadFileTool()
result = await tool.execute(path=str(path), symbol="PI")
# PI is not a def/class — extractor finds no symbols; soft error lists available.
assert result["is_error"] is False
assert result["content"] == ""
assert result["available_symbols"] == []
# ---------------------------------------------------------------------------
# Error paths
# ---------------------------------------------------------------------------
class TestReadFileToolErrors:
@pytest.mark.asyncio
async def test_path_required(self):
tool = ReadFileTool()
result = await tool.execute()
assert result["is_error"] is True
assert "path" in result["error"].lower()
@pytest.mark.asyncio
async def test_path_empty_string(self):
tool = ReadFileTool()
result = await tool.execute(path="")
assert result["is_error"] is True
@pytest.mark.asyncio
async def test_file_not_found(self, tmp_path):
tool = ReadFileTool()
result = await tool.execute(path=str(tmp_path / "missing.py"))
assert result["is_error"] is True
assert "not found" in result["error"].lower()
@pytest.mark.asyncio
async def test_path_is_directory(self, tmp_path):
tool = ReadFileTool()
result = await tool.execute(path=str(tmp_path))
assert result["is_error"] is True
assert "directory" in result["error"].lower()
# ---------------------------------------------------------------------------
# Manual line slicing
# ---------------------------------------------------------------------------
class TestManualLineSlicing:
@pytest.mark.asyncio
async def test_start_and_end_line(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(
path=str(sample_py_file),
start_line=3,
end_line=5,
)
assert result["is_error"] is False
assert result["start_line"] == 3
assert result["end_line"] == 5
# Lines 3-5 of the sample file:
# line 3: "def my_func():"
# line 4: " return 42"
# line 5: "" (blank)
assert "def my_func" in result["content"]
assert "return 42" in result["content"]
@pytest.mark.asyncio
async def test_start_line_only_extends_to_eof(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), start_line=8)
assert result["is_error"] is False
assert result["start_line"] == 8
assert result["end_line"] == result["total_lines"]
@pytest.mark.asyncio
async def test_end_line_only_starts_at_one(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), end_line=2)
assert result["is_error"] is False
assert result["start_line"] == 1
assert result["end_line"] == 2
@pytest.mark.asyncio
async def test_invalid_start_line_zero(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(path=str(sample_py_file), start_line=0)
assert result["is_error"] is True
assert "start_line" in result["error"].lower()
@pytest.mark.asyncio
async def test_end_before_start(self, sample_py_file):
tool = ReadFileTool()
result = await tool.execute(
path=str(sample_py_file),
start_line=5,
end_line=3,
)
assert result["is_error"] is True
assert "end_line" in result["error"].lower()
@pytest.mark.asyncio
async def test_manual_lines_override_symbol(self, sample_py_file):
# Per plan U1 Approach: "start_line/end_line overrides symbol".
tool = ReadFileTool()
result = await tool.execute(
path=str(sample_py_file),
symbol="my_func",
start_line=1,
end_line=1,
)
assert result["is_error"] is False
# Manual slicing won — symbol field absent.
assert "symbol" not in result or result.get("symbol") is None
assert result["start_line"] == 1
assert result["end_line"] == 1
# ---------------------------------------------------------------------------
# Integration — tool registry discovery
# ---------------------------------------------------------------------------
class TestToolRegistryDiscovery:
def test_instantiable_without_args(self):
# Default constructor — matches the convention used by ToolRegistry
# to instantiate tools by class.
tool = ReadFileTool()
assert tool.name == "read_file"
def test_to_dict_serializable(self):
tool = ReadFileTool()
d = tool.to_dict()
assert d["name"] == "read_file"
assert "input_schema" in d
assert "output_schema" in d
assert d["tags"] == ["io", "file", "read"]

View File

@ -0,0 +1,359 @@
"""Unit tests for SymbolExtractor — AstSymbolExtractor + RegexSymbolExtractor.
Covers R22 (file reading supports symbol/function granularity) and KTD1
(Python ast + language-aware regex, no tree-sitter dependency).
"""
from __future__ import annotations
import textwrap
import pytest
from agentkit.tools.symbol_extractor import (
AstSymbolExtractor,
RegexSymbolExtractor,
SymbolSpan,
extract_symbols_from_file,
get_extractor,
language_for_extension,
)
# ---------------------------------------------------------------------------
# language_for_extension
# ---------------------------------------------------------------------------
class TestLanguageForExtension:
def test_python_extensions(self):
assert language_for_extension("py") == "py"
assert language_for_extension(".py") == "py"
assert language_for_extension(".PY") == "py" # case-insensitive
def test_typescript_javascript(self):
assert language_for_extension(".ts") == "ts"
assert language_for_extension(".tsx") == "ts"
assert language_for_extension(".js") == "js"
assert language_for_extension(".jsx") == "js"
assert language_for_extension(".mjs") == "js"
assert language_for_extension(".cjs") == "js"
def test_go_rust_java(self):
assert language_for_extension(".go") == "go"
assert language_for_extension(".rs") == "rs"
assert language_for_extension(".java") == "java"
def test_unsupported_returns_empty(self):
assert language_for_extension(".md") == ""
assert language_for_extension(".txt") == ""
assert language_for_extension("") == ""
assert language_for_extension(".unknown") == ""
# ---------------------------------------------------------------------------
# AstSymbolExtractor — Python
# ---------------------------------------------------------------------------
class TestAstSymbolExtractor:
extractor = AstSymbolExtractor()
def test_unsupported_language_returns_empty(self):
assert self.extractor.extract_symbols("function foo() {}", "ts") == []
def test_syntax_error_returns_empty(self):
# Never raises — callers rely on this for fallback routing.
result = self.extractor.extract_symbols("def broken(:\n pass", "py")
assert result == []
def test_top_level_function(self):
content = "def my_func():\n return 42\n"
spans = self.extractor.extract_symbols(content, "py")
assert len(spans) == 1
span = spans[0]
assert span.name == "my_func"
assert span.kind == "function"
assert span.start_line == 1
assert span.end_line == 2
def test_async_function(self):
content = "async def fetch():\n return 1\n"
spans = self.extractor.extract_symbols(content, "py")
assert len(spans) == 1
assert spans[0].name == "fetch"
assert spans[0].kind == "function"
def test_top_level_class(self):
content = textwrap.dedent('''
class MyClass:
"""docstring"""
def method_a(self):
return 1
async def method_b(self):
return 2
''').strip()
spans = self.extractor.extract_symbols(content, "py")
names = [s.name for s in spans]
assert "MyClass" in names
assert "method_a" in names
assert "method_b" in names
cls = next(s for s in spans if s.name == "MyClass")
assert cls.kind == "class"
assert cls.start_line == 1
# Class body extends through the last method's end_lineno.
assert cls.end_line >= 7
def test_methods_classified_as_methods(self):
content = textwrap.dedent('''
class Foo:
def bar(self):
pass
def top_level():
pass
''').strip()
spans = self.extractor.extract_symbols(content, "py")
by_name = {s.name: s for s in spans}
assert by_name["bar"].kind == "method"
assert by_name["top_level"].kind == "function"
def test_decorated_function(self):
content = textwrap.dedent('''
@staticmethod
def helper():
return "hi"
''').strip()
spans = self.extractor.extract_symbols(content, "py")
# Note: extractor uses node.lineno (def line) — decorators above are
# excluded by design (matches user-visible symbol start at `def`).
assert any(s.name == "helper" for s in spans)
span = next(s for s in spans if s.name == "helper")
assert span.start_line == 2 # the `def` line
def test_nested_function(self):
content = textwrap.dedent('''
def outer():
def inner():
return 1
return inner()
''').strip()
spans = self.extractor.extract_symbols(content, "py")
names = {s.name for s in spans}
assert "outer" in names
assert "inner" in names
def test_empty_file(self):
assert self.extractor.extract_symbols("", "py") == []
def test_no_symbols_in_docstring_only_file(self):
content = '"""just a docstring"""\n'
assert self.extractor.extract_symbols(content, "py") == []
# ---------------------------------------------------------------------------
# RegexSymbolExtractor — TS/JS/Go/Rust/Java
# ---------------------------------------------------------------------------
class TestRegexSymbolExtractor:
extractor = RegexSymbolExtractor()
def test_unsupported_language_returns_empty(self):
assert self.extractor.extract_symbols("def foo(): pass", "py") == []
assert self.extractor.extract_symbols("function foo() {}", "rb") == []
def test_typescript_function_declaration(self):
content = textwrap.dedent('''
export function renderComponent(props: Props): JSX.Element {
return <div/>;
}
''').strip()
spans = self.extractor.extract_symbols(content, "ts")
assert any(s.name == "renderComponent" and s.kind == "function" for s in spans)
def test_typescript_async_function(self):
content = "async function fetchData() {\n return await fetch();\n}\n"
spans = self.extractor.extract_symbols(content, "ts")
assert any(s.name == "fetchData" for s in spans)
def test_typescript_arrow_function_const(self):
content = "const handleClick = (e: Event) => {\n console.log(e);\n};\n"
spans = self.extractor.extract_symbols(content, "ts")
assert any(s.name == "handleClick" for s in spans)
def test_typescript_class(self):
content = textwrap.dedent('''
export abstract class BaseService {
abstract run(): void;
}
''').strip()
spans = self.extractor.extract_symbols(content, "ts")
assert any(s.name == "BaseService" and s.kind == "class" for s in spans)
def test_javascript_function(self):
content = "function foo() {\n return 1;\n}\n"
spans = self.extractor.extract_symbols(content, "js")
assert any(s.name == "foo" for s in spans)
def test_javascript_arrow_const(self):
content = "const bar = () => 42;\n"
spans = self.extractor.extract_symbols(content, "js")
assert any(s.name == "bar" for s in spans)
def test_go_function(self):
content = textwrap.dedent('''
package main
func HandleRequest(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}
func (s *Server) Start() {
// method
}
''').strip()
spans = self.extractor.extract_symbols(content, "go")
names = {s.name for s in spans}
assert "HandleRequest" in names
assert "Start" in names # method receiver pattern
def test_go_struct(self):
content = "type Server struct {\n Addr string\n}\n"
spans = self.extractor.extract_symbols(content, "go")
assert any(s.name == "Server" and s.kind == "struct" for s in spans)
def test_rust_function(self):
content = textwrap.dedent('''
pub fn process(input: &str) -> Result<usize, Error> {
Ok(input.len())
}
async fn fetch() -> Bytes {
unimplemented!()
}
''').strip()
spans = self.extractor.extract_symbols(content, "rs")
names = {s.name for s in spans}
assert "process" in names
assert "fetch" in names
def test_rust_struct(self):
content = "pub struct Config {\n pub path: String,\n}\n"
spans = self.extractor.extract_symbols(content, "rs")
assert any(s.name == "Config" and s.kind == "struct" for s in spans)
def test_rust_impl(self):
content = "impl Config {\n pub fn new() -> Self { Self { path: String::new() } }\n}\n"
spans = self.extractor.extract_symbols(content, "rs")
assert any(s.name == "Config" and s.kind == "impl" for s in spans)
def test_java_class(self):
content = textwrap.dedent('''
package com.example;
public class UserService {
public User findById(long id) {
return null;
}
}
''').strip()
spans = self.extractor.extract_symbols(content, "java")
assert any(s.name == "UserService" and s.kind == "class" for s in spans)
def test_java_method(self):
content = "public User findById(long id) {\n return null;\n}\n"
spans = self.extractor.extract_symbols(content, "java")
assert any(s.name == "findById" and s.kind == "function" for s in spans)
def test_end_line_extends_to_next_symbol(self):
# First symbol's end_line is the line before the second symbol starts.
content = textwrap.dedent('''
function first() {
return 1;
}
function second() {
return 2;
}
''').strip()
spans = self.extractor.extract_symbols(content, "js")
spans.sort(key=lambda s: s.start_line)
first = spans[0]
second = spans[1]
assert first.name == "first"
assert second.name == "second"
assert first.end_line == second.start_line - 1
def test_last_symbol_end_line_is_eof(self):
content = "function only() {\n return 1;\n}\n"
spans = self.extractor.extract_symbols(content, "js")
assert len(spans) == 1
assert spans[0].end_line == len(content.splitlines())
# ---------------------------------------------------------------------------
# get_extractor + integration
# ---------------------------------------------------------------------------
class TestGetExtractor:
def test_python_returns_ast_extractor(self):
ext = get_extractor("py")
assert ext is not None
assert isinstance(ext, AstSymbolExtractor)
def test_typescript_returns_regex_extractor(self):
ext = get_extractor("ts")
assert ext is not None
assert isinstance(ext, RegexSymbolExtractor)
def test_unsupported_returns_none(self):
assert get_extractor("md") is None
assert get_extractor("") is None
assert get_extractor("unknown") is None
class TestExtractSymbolsFromFile:
def test_python_file(self, tmp_path):
path = tmp_path / "module.py"
path.write_text("def hello():\n return 'world'\n", encoding="utf-8")
spans, lang = extract_symbols_from_file(path)
assert lang == "py"
assert any(s.name == "hello" for s in spans)
def test_unsupported_extension(self, tmp_path):
path = tmp_path / "notes.md"
path.write_text("# Hello\n", encoding="utf-8")
spans, lang = extract_symbols_from_file(path)
assert lang == ""
assert spans == []
def test_missing_file_returns_empty(self, tmp_path):
path = tmp_path / "nonexistent.py"
spans, lang = extract_symbols_from_file(path)
# lang is detected from extension even if read fails.
assert lang == "py"
assert spans == []
# ---------------------------------------------------------------------------
# SymbolSpan dataclass
# ---------------------------------------------------------------------------
class TestSymbolSpan:
def test_frozen_dataclass(self):
span = SymbolSpan(name="foo", kind="function", start_line=1, end_line=3)
assert span.name == "foo"
with pytest.raises(Exception):
span.name = "bar" # type: ignore[misc] — frozen
def test_equality(self):
a = SymbolSpan("foo", "function", 1, 3)
b = SymbolSpan("foo", "function", 1, 3)
assert a == b
assert hash(a) == hash(b)