feat(U1): G5 SymbolExtractor + ReadFileTool with symbol slicing
- SymbolExtractor protocol + SymbolSpan dataclass
- AstSymbolExtractor for Python (stdlib ast, no tree-sitter dep — KTD1)
- RegexSymbolExtractor for TS/JS/Go/Rust/Java (language-aware regex)
- ReadFileTool with path/symbol/start_line/end_line params
- symbol=None returns full file (characterization baseline matching _FakeTool)
- symbol='foo' returns first matching symbol's line range
- symbol not found returns available_symbols list (soft error)
- Unsupported extension returns full file with note
- Manual start_line/end_line overrides symbol
- 66 unit tests covering R22/R23 + characterization + edge cases
This commit is contained in:
parent
e3f69f963c
commit
58ef1719cb
|
|
@ -18,6 +18,7 @@ from agentkit.tools.memory_tool import MemoryTool
|
||||||
from agentkit.tools.web_search import WebSearchTool
|
from agentkit.tools.web_search import WebSearchTool
|
||||||
from agentkit.tools.builtin import RunTestsTool, ToolSearchTool
|
from agentkit.tools.builtin import RunTestsTool, ToolSearchTool
|
||||||
from agentkit.tools.search import ToolSearchIndex
|
from agentkit.tools.search import ToolSearchIndex
|
||||||
|
from agentkit.tools.file_read import ReadFileTool
|
||||||
|
|
||||||
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
|
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
|
||||||
try:
|
try:
|
||||||
|
|
@ -52,4 +53,5 @@ __all__ = [
|
||||||
"OutputParser",
|
"OutputParser",
|
||||||
"ParsedOutput",
|
"ParsedOutput",
|
||||||
"ErrorType",
|
"ErrorType",
|
||||||
|
"ReadFileTool",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,262 @@
|
||||||
|
"""ReadFileTool — file reading with optional symbol-level sharding (G5, R22/R23).
|
||||||
|
|
||||||
|
Backward compatible with the pre-existing `_FakeTool` benchmark shape — when
|
||||||
|
`symbol=None`, returns the full file content. When `symbol="foo"`, returns
|
||||||
|
the line range of the first matching symbol via `SymbolExtractor`.
|
||||||
|
|
||||||
|
KTD2 (Wave 3 plan): dedicated tool, does NOT extend ShellTool — keeps the
|
||||||
|
file-reading contract clean and gives the LLM a focused schema.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
from agentkit.tools.symbol_extractor import (
|
||||||
|
SymbolSpan,
|
||||||
|
extract_symbols_from_file,
|
||||||
|
language_for_extension,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ReadFileTool(Tool):
    """Read a file from the filesystem, optionally sliced to a single symbol.

    Tool name `read_file` matches the reserved entry in
    `core/react.py:_DEFAULT_CORE_TOOLS` (which previously had no real
    implementation — only `_FakeTool` stubs in `cli/benchmark.py`).

    Backward-compat contract: `symbol=None` returns the full file content,
    matching the shape `{"path": ...}` that downstream callers (benchmark,
    phase whitelist) already expect.

    Every result is a dict carrying at least `content` and `is_error`.
    Failures (missing path, bad line numbers, unreadable file) set
    `is_error=True` and an `error` message; a symbol that is not found is a
    *soft* error: `is_error=False`, empty `content`, and `available_symbols`.
    """

    def __init__(
        self,
        name: str = "read_file",
        description: str | None = None,
        input_schema: dict[str, Any] | None = None,
        output_schema: dict[str, Any] | None = None,
        version: str = "1.0.0",
        tags: list[str] | None = None,
    ):
        """Configure the tool; every argument falls back to a built-in default."""
        super().__init__(
            name=name,
            description=description
            or (
                "Read a file from the filesystem. By default returns the full file "
                "content. Pass `symbol` (function/class/struct name) to slice to just "
                "that symbol's line range — saves context when you only need one "
                "function from a large file. Pass `start_line`/`end_line` for manual "
                "slicing. If `symbol` is set but not found, returns the available "
                "symbol names so you can retry."
            ),
            input_schema=input_schema or self._default_input_schema(),
            output_schema=output_schema or self._default_output_schema(),
            version=version,
            tags=tags or ["io", "file", "read"],
        )

    @staticmethod
    def _default_input_schema() -> dict[str, Any]:
        """Return the JSON schema for the arguments accepted by `execute`."""
        return {
            "type": "object",
            "properties": {
                "path": {
                    "type": "string",
                    "description": "Path to the file to read (absolute or relative to cwd).",
                },
                "symbol": {
                    "type": "string",
                    "description": (
                        "Optional: name of a function/class/struct/method to slice to. "
                        "When set, returns only the line range of the first matching "
                        "symbol. Supported languages: py, ts/tsx, js/jsx, go, rs, java."
                    ),
                },
                "start_line": {
                    "type": "integer",
                    "description": "Optional 1-based start line for manual slicing. Overrides `symbol`.",
                    "minimum": 1,
                },
                "end_line": {
                    "type": "integer",
                    "description": "Optional 1-based end line (inclusive) for manual slicing. Overrides `symbol`.",
                    "minimum": 1,
                },
            },
            "required": ["path"],
            "additionalProperties": False,
        }

    @staticmethod
    def _default_output_schema() -> dict[str, Any]:
        """Return the JSON schema describing every key `execute` may emit."""
        return {
            "type": "object",
            "properties": {
                "content": {"type": "string"},
                "path": {"type": "string"},
                "start_line": {"type": "integer"},
                "end_line": {"type": "integer"},
                "total_lines": {"type": "integer"},
                "symbol": {"type": "string"},
                "symbol_kind": {"type": "string"},
                "available_symbols": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Populated when `symbol` is set but not found.",
                },
                "note": {"type": "string"},
                "is_error": {"type": "boolean"},
                "error": {"type": "string"},
                "detail": {"type": "string"},
            },
        }

    async def execute(self, **kwargs) -> dict[str, Any]:
        """Read `path` and return the whole file, a line range, or one symbol.

        Keyword Args:
            path: File to read; relative paths are resolved against the cwd.
            symbol: Optional symbol name to slice to (first match wins).
            start_line: Optional 1-based first line; overrides `symbol`.
            end_line: Optional 1-based inclusive last line; overrides `symbol`.

        Returns:
            A result dict (see the class docstring). This method reports
            problems through `is_error`/`error` instead of raising.
        """
        raw_path = kwargs.get("path")
        if not raw_path:
            return self._error("`path` is required")

        try:
            path = Path(raw_path)
        except TypeError:
            # A non-path value (e.g. a number) must become a tool error, not a crash.
            return self._error(f"`path` must be a string, got {type(raw_path).__name__}")
        if not path.is_absolute():
            path = path.resolve()

        symbol = kwargs.get("symbol")
        start_line = kwargs.get("start_line")
        end_line = kwargs.get("end_line")

        # Validate line overrides. `bool` is a subclass of `int`, so it is
        # rejected explicitly: `True` is not a meaningful line number.
        for label, value in (("start_line", start_line), ("end_line", end_line)):
            if value is None:
                continue
            if isinstance(value, bool) or not isinstance(value, int) or value < 1:
                return self._error(f"`{label}` must be a positive integer, got {value!r}")
        if start_line is not None and end_line is not None and end_line < start_line:
            return self._error(f"`end_line` ({end_line}) must be >= `start_line` ({start_line})")

        # Filesystem checks.
        if not path.exists():
            return self._error(f"File not found: {path}", path=str(path))
        if path.is_dir():
            return self._error(f"Path is a directory, not a file: {path}", path=str(path))

        try:
            content = path.read_text(encoding="utf-8", errors="replace")
        except PermissionError as e:
            return self._error(f"Permission denied: {path}", path=str(path), detail=str(e))
        except OSError as e:
            return self._error(f"Failed to read {path}: {e}", path=str(path))

        lines = content.splitlines()
        total_lines = len(lines)

        # Manual slicing takes precedence over symbol (per plan U1 Approach).
        if start_line is not None or end_line is not None:
            s = start_line or 1
            # The end is clamped to the file length; a start beyond EOF
            # therefore yields empty content rather than an error.
            e = min(total_lines, end_line or total_lines)
            return {
                "content": "\n".join(lines[s - 1 : e]),
                "path": str(path),
                "start_line": s,
                "end_line": e,
                "total_lines": total_lines,
                "is_error": False,
            }

        # Symbol slicing.
        if symbol:
            ext = path.suffix.lower()
            language = language_for_extension(ext)
            if not language:
                # Unsupported extension: return full file with note (per plan U1 Edge case).
                return {
                    "content": content,
                    "path": str(path),
                    "start_line": 1,
                    "end_line": total_lines,
                    "total_lines": total_lines,
                    "note": f"symbol extraction not supported for {ext or 'unknown extension'}",
                    "is_error": False,
                }

            # Extract from the text already in memory: no second disk read, and
            # the spans are guaranteed to describe the same bytes as `lines`.
            from agentkit.tools.symbol_extractor import get_extractor

            extractor = get_extractor(language)
            spans = extractor.extract_symbols(content, language) if extractor is not None else []

            match = _find_symbol(spans, symbol)
            if match is None:
                available = sorted({s.name for s in spans})
                return {
                    "content": "",
                    "path": str(path),
                    "symbol": symbol,
                    "available_symbols": available,
                    "is_error": False,
                    "note": (
                        f"Symbol {symbol!r} not found in {path.name}. "
                        f"Available: {', '.join(available) if available else '(none)'}"
                    ),
                }

            s = match.start_line
            e = min(match.end_line, total_lines)
            return {
                "content": "\n".join(lines[s - 1 : e]),
                "path": str(path),
                "symbol": symbol,
                "symbol_kind": match.kind,
                "start_line": s,
                "end_line": e,
                "total_lines": total_lines,
                "is_error": False,
            }

        # Default: full file (characterization baseline — matches _FakeTool shape).
        return {
            "content": content,
            "path": str(path),
            "start_line": 1,
            "end_line": total_lines,
            "total_lines": total_lines,
            "is_error": False,
        }

    @staticmethod
    def _error(
        message: str, *, path: str | None = None, detail: str | None = None
    ) -> dict[str, Any]:
        """Build an error result; `path` and `detail` are included only when given."""
        result: dict[str, Any] = {
            "content": "",
            "is_error": True,
            "error": message,
        }
        if path is not None:
            result["path"] = path
        if detail is not None:
            result["detail"] = detail
        return result
|
||||||
|
|
||||||
|
|
||||||
|
def _find_symbol(spans: list[SymbolSpan], name: str) -> SymbolSpan | None:
    """Return the first span in `spans` named exactly `name`, else None.

    The comparison is case-sensitive. When several spans share the name, the
    one that appears earliest in `spans` is returned, so the order of the
    list supplied by the caller decides ambiguous lookups.

    ponytail: a linear scan is fine for typical file symbol counts (<100).
    """
    return next((candidate for candidate in spans if candidate.name == name), None)
|
||||||
|
|
@ -0,0 +1,278 @@
|
||||||
|
"""Symbol extraction — locate code symbols (functions/classes/structs) by name.
|
||||||
|
|
||||||
|
KTD1 (Wave 3 plan): Python `ast` (stdlib) for .py files; language-aware regex
|
||||||
|
for TS/JS/Go/Rust/Java. Avoids tree-sitter native dependency. The
|
||||||
|
`SymbolExtractor` protocol is the upgrade seam — a future TreeSitterSymbolExtractor
|
||||||
|
can replace RegexSymbolExtractor behind the same interface.
|
||||||
|
|
||||||
|
ponytail: regex extractor covers ~80% case (top-level function/class/struct
|
||||||
|
declarations). Ceiling: misses nested signatures inside JSX/TSX generics,
|
||||||
|
multi-line decorator chains, and macro-generated defs. Upgrade path = tree-sitter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class SymbolSpan:
|
||||||
|
"""A located symbol — name, kind, and 1-based inclusive line range."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
kind: str # "function" | "class" | "method" | "struct" | "impl"
|
||||||
|
start_line: int # 1-based, inclusive
|
||||||
|
end_line: int # 1-based, inclusive
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
class SymbolExtractor(Protocol):
    """Structural interface every symbol extractor satisfies.

    Decorated with `runtime_checkable`, so `isinstance` checks against this
    protocol work at runtime.
    """

    def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
        """Locate every symbol in `content` and return them as spans.

        Args:
            content: Full source text to scan.
            language: Short language key without a leading dot, e.g. "py"
                or "ts".

        Returns:
            The symbols found. Implementations must not raise when
            extraction fails: they return an empty list on parse errors so
            the caller can choose a fallback (such as a full-file read).
        """
        ...
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Python — stdlib ast
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class AstSymbolExtractor:
    """Python symbol extractor using the stdlib `ast` module.

    Captures every FunctionDef/AsyncFunctionDef/ClassDef in the module, at
    any nesting depth. A function is reported as kind "method" when its
    nearest enclosing scope is a class body, and as "function" otherwise
    (including functions nested inside other functions). Spans are returned
    in source order (ascending `start_line`), so "first match" lookups pick
    the definition that appears first in the file.
    """

    def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
        """Return the symbols defined in Python source `content`.

        Args:
            content: Python source text.
            language: Language key; anything other than "py" yields [].

        Returns:
            Spans sorted by start line, or [] when the source cannot be
            parsed. Parse failures are logged at debug level, never raised.
        """
        if language != "py":
            return []
        try:
            tree = ast.parse(content)
        except (SyntaxError, ValueError, RecursionError) as e:
            # ValueError: null bytes in the source on some Python versions.
            # RecursionError: pathologically deep nesting. Either way the
            # extractor contract is "return [] rather than raise".
            logger.debug("ast.parse failed: %s", e)
            return []

        lines = content.splitlines()
        spans: list[SymbolSpan] = []
        # Depth-first walk that remembers whether the nearest enclosing scope
        # is a class body; `ast.walk` exposes no parent information.
        pending: list[tuple[ast.AST, bool]] = [(tree, False)]
        while pending:
            node, in_class_body = pending.pop()
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                kind = "method" if in_class_body else "function"
                spans.append(_span_from_node(node, kind, lines))
                child_in_class = False
            elif isinstance(node, ast.ClassDef):
                spans.append(_span_from_node(node, "class", lines))
                child_in_class = True
            elif isinstance(node, ast.Lambda):
                child_in_class = False
            else:
                # if/try/with/for blocks do not open a scope: inherit the flag.
                child_in_class = in_class_body
            pending.extend(
                (child, child_in_class) for child in ast.iter_child_nodes(node)
            )

        # Traversal order is an implementation detail; expose source order.
        spans.sort(key=lambda span: span.start_line)
        return spans
|
||||||
|
|
||||||
|
|
||||||
|
def _is_method(node: ast.AST) -> bool:
    """Heuristically decide whether `node` should be labelled a method.

    No parent information is consulted; the only signal used is the node's
    `col_offset`. Any node that starts at a column greater than zero counts
    as a method, and a node without a `col_offset` attribute is treated as
    column zero. ponytail: this also labels functions nested inside other
    functions as methods — upgrade path = ast.NodeVisitor with parent
    tracking.
    """
    indent = getattr(node, "col_offset", 0)
    return indent > 0
|
||||||
|
|
||||||
|
|
||||||
|
def _span_from_node(
    node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef,
    kind: str,
    lines: list[str],
) -> SymbolSpan:
    """Convert a def/class AST node into a `SymbolSpan` of the given `kind`.

    The span starts at `node.lineno` — the line of the `def`/`class`
    keyword — so decorator lines above the keyword are not part of the
    range. It ends at `node.end_lineno`, falling back to the start line when
    that attribute is None, and is capped at `len(lines)` as a defensive
    clamp to the real file length. All line numbers are 1-based.
    """
    first_line = node.lineno
    last_line = min(node.end_lineno or first_line, len(lines))
    return SymbolSpan(name=node.name, kind=kind, start_line=first_line, end_line=last_line)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Regex extractor — TS/JS/Go/Rust/Java
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Declaration patterns, keyed by language. Each entry pairs the symbol kind
# with a compiled pattern whose group 1 captures the symbol name. All patterns
# are compiled with re.MULTILINE so `^` anchors at every line start.
# NOTE(review): the Java "function" pattern also accepts control-flow lines
# shaped like `else if (x) {` (captured name "if") — confirm whether keyword
# filtering is wanted.
_REGEX_PATTERNS: dict[str, list[tuple[str, re.Pattern[str]]]] = {
    # TypeScript: function declarations, classes, arrow functions bound to a name.
    "ts": [
        (
            "function",
            re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE),
        ),
        (
            "class",
            re.compile(r"^\s*(?:export\s+)?(?:abstract\s+)?class\s+(\w+)\b", re.MULTILINE),
        ),
        (
            "function",
            re.compile(
                r"^\s*(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>",
                re.MULTILINE,
            ),
        ),
    ],
    # JavaScript: same shapes as TypeScript, without the `export`/`abstract` prefixes.
    "js": [
        (
            "function",
            re.compile(r"^\s*(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE),
        ),
        (
            "class",
            re.compile(r"^\s*class\s+(\w+)\b", re.MULTILINE),
        ),
        (
            "function",
            re.compile(
                r"^\s*(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>",
                re.MULTILINE,
            ),
        ),
    ],
    # Go: column-zero funcs (with optional receiver) and struct types.
    "go": [
        (
            "function",
            re.compile(r"^func\s+(?:\([^)]*\)\s+)?(\w+)\s*\(", re.MULTILINE),
        ),
        (
            "struct",
            re.compile(r"^type\s+(\w+)\s+struct\b", re.MULTILINE),
        ),
    ],
    # Rust: fns, structs, and column-zero impl blocks.
    "rs": [
        (
            "function",
            re.compile(r"^\s*(?:pub\s+)?(?:async\s+)?fn\s+(\w+)\s*\(", re.MULTILINE),
        ),
        (
            "struct",
            re.compile(r"^\s*(?:pub\s+)?struct\s+(\w+)\b", re.MULTILINE),
        ),
        (
            "impl",
            re.compile(r"^impl\b.*?\s+(\w+)\s*\{", re.MULTILINE),
        ),
    ],
    # Java: methods with a body on the signature line, and classes.
    "java": [
        (
            "function",
            re.compile(
                r"^\s*(?:public|private|protected)?\s*(?:static\s+)?(?:final\s+)?(?:\w+(?:<[^>]*>)?)\s+(\w+)\s*\([^)]*\)\s*(?:throws\s+\w+(?:\s*,\s*\w+)*)?\s*\{",
                re.MULTILINE,
            ),
        ),
        (
            "class",
            re.compile(
                r"^\s*(?:public\s+)?(?:abstract\s+|final\s+)?class\s+(\w+)\b", re.MULTILINE
            ),
        ),
    ],
}
|
||||||
|
|
||||||
|
|
||||||
|
class RegexSymbolExtractor:
    """Language-aware regex symbol extractor for TS/JS/Go/Rust/Java.

    A symbol's `start_line` is the line holding its name. Its `end_line` is
    approximated as the line just before the next detected symbol starts, or
    the last line of the file for the final symbol. ponytail: this is an
    approximation — true block-end requires language-aware brace matching.
    Ceiling: a range can over-extend into trailing code that belongs to no
    symbol, and a symbol nested in another cuts the outer one short. Upgrade
    path = tree-sitter.
    """

    def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
        """Return spans for every declaration matched in `content`.

        Args:
            content: Source text to scan.
            language: Key into `_REGEX_PATTERNS`; unknown keys yield [].

        Returns:
            Spans sorted by start line; [] when nothing matches.
        """
        patterns = _REGEX_PATTERNS.get(language)
        if not patterns:
            return []

        total_lines = len(content.splitlines())

        # Collect (line_no, name, kind) hits, dropping repeats of the same
        # (line_no, name) produced by overlapping patterns (first one wins).
        seen: set[tuple[int, str]] = set()
        hits: list[tuple[int, str, str]] = []
        for kind, pattern in patterns:
            for m in pattern.finditer(content):
                # Locate the line via the captured name, not the match start:
                # under MULTILINE a leading `^\s*` can begin on blank lines
                # *above* the declaration, which would shift the start upward.
                line_no = content.count("\n", 0, m.start(1)) + 1
                name = m.group(1)
                if (line_no, name) in seen:
                    continue
                seen.add((line_no, name))
                hits.append((line_no, name, kind))

        if not hits:
            return []

        hits.sort(key=lambda hit: hit[0])

        spans: list[SymbolSpan] = []
        for i, (start_line, name, kind) in enumerate(hits):
            if i + 1 < len(hits):
                # End at the line before the next symbol starts.
                end_line = hits[i + 1][0] - 1
            else:
                end_line = total_lines
            # Two symbols on one line would otherwise give an inverted range.
            end_line = max(end_line, start_line)
            spans.append(SymbolSpan(name=name, kind=kind, start_line=start_line, end_line=end_line))
        return spans
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Dispatch by file extension
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Lower-case file extension (with leading dot) -> language key. Several
# extensions share one key: .tsx uses "ts"; .jsx/.mjs/.cjs use "js".
_EXTENSION_LANGUAGE: dict[str, str] = {
    # Python
    ".py": "py",
    # TypeScript family
    ".ts": "ts",
    ".tsx": "ts",
    # JavaScript family
    ".js": "js",
    ".jsx": "js",
    ".mjs": "js",
    ".cjs": "js",
    # Go, Rust, Java
    ".go": "go",
    ".rs": "rs",
    ".java": "java",
}
|
||||||
|
|
||||||
|
# Module-level extractor instances, created once at import time and returned
# by `get_extractor`: the AST-based one for "py", the regex-based one for
# every language listed in `_REGEX_PATTERNS`.
_DEFAULT_EXTRACTOR = AstSymbolExtractor()
_REGEX_EXTRACTOR = RegexSymbolExtractor()
|
||||||
|
|
||||||
|
|
||||||
|
def language_for_extension(ext: str) -> str:
    """Map a file extension to its language key.

    The extension may be given with or without the leading dot and in any
    letter case. Extensions absent from `_EXTENSION_LANGUAGE` map to "".
    """
    dotted = ext if ext.startswith(".") else "." + ext
    return _EXTENSION_LANGUAGE.get(dotted.lower(), "")
|
||||||
|
|
||||||
|
|
||||||
|
def get_extractor(language: str) -> SymbolExtractor | None:
    """Pick the extractor instance that handles `language`.

    "py" is served by the AST extractor, every key of `_REGEX_PATTERNS` by
    the regex extractor; any other language gets None.
    """
    if language == "py":
        return _DEFAULT_EXTRACTOR
    return _REGEX_EXTRACTOR if language in _REGEX_PATTERNS else None
|
||||||
|
|
||||||
|
|
||||||
|
def extract_symbols_from_file(path: Path) -> tuple[list[SymbolSpan], str]:
    """Load `path` from disk and extract its symbols.

    Returns:
        A `(symbols, language)` pair. For an unsupported extension the pair
        is `([], "")`. When the file cannot be read (OSError, logged at
        debug level) or no extractor exists for the language, the symbol
        list is empty but the detected language is still returned. The text
        is decoded as UTF-8 with undecodable bytes replaced.
    """
    language = language_for_extension(path.suffix.lower())
    if not language:
        return [], ""

    try:
        text = path.read_text(encoding="utf-8", errors="replace")
    except OSError as e:
        logger.debug("read failed for %s: %s", path, e)
        return [], language

    extractor = get_extractor(language)
    symbols = [] if extractor is None else extractor.extract_symbols(text, language)
    return symbols, language
|
||||||
|
|
@ -0,0 +1,367 @@
|
||||||
|
"""Unit tests for ReadFileTool — G5 (R22, R23) + characterization baseline.
|
||||||
|
|
||||||
|
Per plan U1 Execution note: characterization-first — assert that
|
||||||
|
`symbol=None` returns the full file content (matches pre-existing benchmark
|
||||||
|
`_FakeTool` shape) before adding symbol-extraction behavior.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import textwrap
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.tools.file_read import ReadFileTool
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Schema
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestReadFileToolSchema:
    """Static checks on the tool's name, tags, and default input schema."""

    def test_name_is_read_file(self):
        assert ReadFileTool().name == "read_file"

    def test_required_path(self):
        schema = ReadFileTool().input_schema
        assert "path" in schema["required"]
        assert "path" in schema["properties"]

    def test_optional_symbol_and_lines(self):
        schema = ReadFileTool().input_schema
        properties = schema["properties"]
        assert "symbol" in properties
        assert "start_line" in properties
        assert "end_line" in properties
        # Only `path` is mandatory; the slicing arguments stay optional.
        assert set(schema["required"]) == {"path"}

    def test_additional_properties_false(self):
        # Unknown arguments from the LLM must be rejected (Wave 1 U3 pattern).
        schema = ReadFileTool().input_schema
        assert schema.get("additionalProperties") is False

    def test_tags_contain_io_and_read(self):
        tags = ReadFileTool().tags
        assert "io" in tags
        assert "read" in tags
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Characterization — symbol=None returns full file
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sample_py_file(tmp_path):
    """Write a small Python module (one function, one class with a method)."""
    source = textwrap.dedent('''
        """Sample module."""

        def my_func():
            return 42


        class MyClass:
            attr = 1

            def method_a(self):
                return self.attr
    ''').lstrip()
    target = tmp_path / "sample.py"
    target.write_text(source, encoding="utf-8")
    return target
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def sample_ts_file(tmp_path):
    """Write a small TypeScript file (one exported function, one exported class)."""
    source = textwrap.dedent('''
        export function renderComponent(): JSX.Element {
            return <div/>;
        }

        export class BaseService {
            abstract run(): void;
        }
    ''').lstrip()
    target = tmp_path / "sample.ts"
    target.write_text(source, encoding="utf-8")
    return target
|
||||||
|
|
||||||
|
|
||||||
|
class TestCharacterizationFullFile:
    """symbol=None returns the whole file (matches _FakeTool baseline)."""

    @pytest.mark.asyncio
    async def test_full_file_returned_when_symbol_none(self, sample_py_file):
        expected_path = str(sample_py_file)
        result = await ReadFileTool().execute(path=expected_path)

        assert result["is_error"] is False
        assert result["path"] == expected_path
        # The reported range spans the entire file.
        assert result["start_line"] == 1
        assert result["end_line"] == result["total_lines"]
        body = result["content"]
        assert "def my_func" in body
        assert "class MyClass" in body
        assert body.startswith('"""Sample module."""')

    @pytest.mark.asyncio
    async def test_full_file_includes_all_lines(self, sample_py_file):
        result = await ReadFileTool().execute(path=str(sample_py_file))
        line_count = result["total_lines"]
        assert line_count >= 8
        assert result["content"].count("\n") >= line_count - 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Symbol slicing — happy paths
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSymbolSlicing:
    """Happy paths: `symbol=<name>` returns only that symbol's lines."""

    @pytest.mark.asyncio
    async def test_python_function(self, sample_py_file):
        result = await ReadFileTool().execute(path=str(sample_py_file), symbol="my_func")

        assert result["is_error"] is False
        assert result["symbol"] == "my_func"
        assert result["symbol_kind"] == "function"
        sliced = result["content"]
        assert "def my_func" in sliced
        assert "return 42" in sliced
        # The class defined further down must be excluded.
        assert "class MyClass" not in sliced

    @pytest.mark.asyncio
    async def test_python_class_includes_method(self, sample_py_file):
        result = await ReadFileTool().execute(path=str(sample_py_file), symbol="MyClass")

        assert result["is_error"] is False
        assert result["symbol"] == "MyClass"
        assert result["symbol_kind"] == "class"
        sliced = result["content"]
        assert "class MyClass" in sliced
        assert "def method_a" in sliced  # method included

    @pytest.mark.asyncio
    async def test_python_method_directly(self, sample_py_file):
        result = await ReadFileTool().execute(path=str(sample_py_file), symbol="method_a")

        assert result["is_error"] is False
        assert result["symbol"] == "method_a"
        assert result["symbol_kind"] == "method"
        assert "def method_a" in result["content"]

    @pytest.mark.asyncio
    async def test_typescript_function(self, sample_ts_file):
        result = await ReadFileTool().execute(
            path=str(sample_ts_file), symbol="renderComponent"
        )

        assert result["is_error"] is False
        assert result["symbol"] == "renderComponent"
        sliced = result["content"]
        assert "renderComponent" in sliced
        # The class defined further down must be excluded.
        assert "BaseService" not in sliced

    @pytest.mark.asyncio
    async def test_typescript_class(self, sample_ts_file):
        result = await ReadFileTool().execute(path=str(sample_ts_file), symbol="BaseService")

        assert result["is_error"] is False
        assert result["symbol"] == "BaseService"
        assert result["symbol_kind"] == "class"
        assert "BaseService" in result["content"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Symbol slicing — edge cases
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSymbolSlicingEdgeCases:
    """Unknown symbols, unsupported extensions, and empty or symbol-free files."""

    @pytest.mark.asyncio
    async def test_symbol_not_found_lists_available(self, sample_py_file):
        result = await ReadFileTool().execute(path=str(sample_py_file), symbol="nonexistent")

        assert result["is_error"] is False  # soft error, not hard
        assert result["content"] == ""
        assert result["symbol"] == "nonexistent"
        names = result["available_symbols"]
        assert "my_func" in names
        assert "MyClass" in names
        assert "method_a" in names
        assert "nonexistent" not in result["content"]

    @pytest.mark.asyncio
    async def test_unsupported_extension_returns_full_with_note(self, tmp_path):
        markdown = tmp_path / "notes.md"
        markdown.write_text("# Hello\nworld\n", encoding="utf-8")

        result = await ReadFileTool().execute(path=str(markdown), symbol="anything")

        assert result["is_error"] is False
        assert result["content"] == "# Hello\nworld\n"
        note = result["note"]
        assert "symbol extraction not supported" in note
        assert ".md" in note

    @pytest.mark.asyncio
    async def test_empty_file(self, tmp_path):
        blank = tmp_path / "empty.py"
        blank.write_text("", encoding="utf-8")

        result = await ReadFileTool().execute(path=str(blank))

        assert result["is_error"] is False
        assert result["content"] == ""
        assert result["total_lines"] == 0

    @pytest.mark.asyncio
    async def test_file_with_no_symbols(self, tmp_path):
        constants = tmp_path / "data.py"
        constants.write_text("# just a comment\nPI = 3.14\n", encoding="utf-8")

        result = await ReadFileTool().execute(path=str(constants), symbol="PI")

        # PI is an assignment, not a def/class, so nothing is extracted and
        # the soft error reports an empty list of available symbols.
        assert result["is_error"] is False
        assert result["content"] == ""
        assert result["available_symbols"] == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Error paths
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestReadFileToolErrors:
    """Hard-error paths: absent or empty path, missing file, directory path."""

    @pytest.mark.asyncio
    async def test_path_required(self):
        """Calling execute() without arguments reports an error that mentions ``path``."""
        outcome = await ReadFileTool().execute()

        assert outcome["is_error"] is True
        assert "path" in outcome["error"].lower()

    @pytest.mark.asyncio
    async def test_path_empty_string(self):
        """An empty path string is rejected."""
        outcome = await ReadFileTool().execute(path="")

        assert outcome["is_error"] is True

    @pytest.mark.asyncio
    async def test_file_not_found(self, tmp_path):
        """A path that does not exist yields a "not found" error."""
        absent = tmp_path / "missing.py"

        outcome = await ReadFileTool().execute(path=str(absent))

        assert outcome["is_error"] is True
        assert "not found" in outcome["error"].lower()

    @pytest.mark.asyncio
    async def test_path_is_directory(self, tmp_path):
        """Pointing the tool at a directory yields an error that says so."""
        outcome = await ReadFileTool().execute(path=str(tmp_path))

        assert outcome["is_error"] is True
        assert "directory" in outcome["error"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Manual line slicing
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestManualLineSlicing:
    """Explicit ``start_line`` / ``end_line`` slicing and its validation."""

    @pytest.mark.asyncio
    async def test_start_and_end_line(self, sample_py_file):
        """Both bounds given: the reported range and content match the request."""
        outcome = await ReadFileTool().execute(
            path=str(sample_py_file),
            start_line=3,
            end_line=5,
        )

        assert outcome["is_error"] is False
        assert outcome["start_line"] == 3
        assert outcome["end_line"] == 5
        # In the sample fixture, lines 3-5 are the `def my_func():` header,
        # its `return 42` body, and a blank line.
        snippet = outcome["content"]
        assert "def my_func" in snippet
        assert "return 42" in snippet

    @pytest.mark.asyncio
    async def test_start_line_only_extends_to_eof(self, sample_py_file):
        """Only ``start_line`` given: the slice runs through the last line."""
        outcome = await ReadFileTool().execute(path=str(sample_py_file), start_line=8)

        assert outcome["is_error"] is False
        assert outcome["start_line"] == 8
        assert outcome["end_line"] == outcome["total_lines"]

    @pytest.mark.asyncio
    async def test_end_line_only_starts_at_one(self, sample_py_file):
        """Only ``end_line`` given: the slice begins at line 1."""
        outcome = await ReadFileTool().execute(path=str(sample_py_file), end_line=2)

        assert outcome["is_error"] is False
        assert outcome["start_line"] == 1
        assert outcome["end_line"] == 2

    @pytest.mark.asyncio
    async def test_invalid_start_line_zero(self, sample_py_file):
        """``start_line=0`` is rejected with an error naming the parameter."""
        outcome = await ReadFileTool().execute(path=str(sample_py_file), start_line=0)

        assert outcome["is_error"] is True
        assert "start_line" in outcome["error"].lower()

    @pytest.mark.asyncio
    async def test_end_before_start(self, sample_py_file):
        """An ``end_line`` smaller than ``start_line`` is rejected."""
        outcome = await ReadFileTool().execute(
            path=str(sample_py_file),
            start_line=5,
            end_line=3,
        )

        assert outcome["is_error"] is True
        assert "end_line" in outcome["error"].lower()

    @pytest.mark.asyncio
    async def test_manual_lines_override_symbol(self, sample_py_file):
        """Explicit line bounds take precedence over a ``symbol`` argument."""
        # Per plan U1 Approach: "start_line/end_line overrides symbol".
        outcome = await ReadFileTool().execute(
            path=str(sample_py_file),
            symbol="my_func",
            start_line=1,
            end_line=1,
        )

        assert outcome["is_error"] is False
        # Manual slicing won, so the symbol field is either missing or None.
        assert outcome.get("symbol") is None
        assert outcome["start_line"] == 1
        assert outcome["end_line"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Integration — tool registry discovery
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolRegistryDiscovery:
    """Checks that ReadFileTool fits the registry's construction/serialization shape."""

    def test_instantiable_without_args(self):
        """The tool can be built with no constructor arguments."""
        # ToolRegistry instantiates tools by class, so the default
        # constructor has to work on its own.
        assert ReadFileTool().name == "read_file"

    def test_to_dict_serializable(self):
        """``to_dict`` exposes the name, both schemas, and the tag list."""
        payload = ReadFileTool().to_dict()

        assert payload["name"] == "read_file"
        for key in ("input_schema", "output_schema"):
            assert key in payload
        assert payload["tags"] == ["io", "file", "read"]
|
||||||
|
|
@ -0,0 +1,359 @@
|
||||||
|
"""Unit tests for SymbolExtractor — AstSymbolExtractor + RegexSymbolExtractor.
|
||||||
|
|
||||||
|
Covers R22 (file reading supports symbol/function granularity) and KTD1
|
||||||
|
(Python ast + language-aware regex, no tree-sitter dependency).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import textwrap
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.tools.symbol_extractor import (
|
||||||
|
AstSymbolExtractor,
|
||||||
|
RegexSymbolExtractor,
|
||||||
|
SymbolSpan,
|
||||||
|
extract_symbols_from_file,
|
||||||
|
get_extractor,
|
||||||
|
language_for_extension,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# language_for_extension
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestLanguageForExtension:
    """Mapping from a file extension to the extractor language key."""

    def test_python_extensions(self):
        """The leading dot is optional and matching ignores case."""
        for spelling in ("py", ".py", ".PY"):
            assert language_for_extension(spelling) == "py"

    def test_typescript_javascript(self):
        """TS/TSX map to ``ts``; JS/JSX/MJS/CJS map to ``js``."""
        expected = {
            ".ts": "ts",
            ".tsx": "ts",
            ".js": "js",
            ".jsx": "js",
            ".mjs": "js",
            ".cjs": "js",
        }
        for extension, language in expected.items():
            assert language_for_extension(extension) == language

    def test_go_rust_java(self):
        """Go, Rust and Java extensions map to their own keys."""
        expected = {".go": "go", ".rs": "rs", ".java": "java"}
        for extension, language in expected.items():
            assert language_for_extension(extension) == language

    def test_unsupported_returns_empty(self):
        """Unknown or empty extensions map to the empty string."""
        for extension in (".md", ".txt", "", ".unknown"):
            assert language_for_extension(extension) == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# AstSymbolExtractor — Python
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAstSymbolExtractor:
    """Behavior of the ``ast``-based extractor on Python source text."""

    extractor = AstSymbolExtractor()

    def test_unsupported_language_returns_empty(self):
        """Non-Python language keys produce no symbols."""
        assert self.extractor.extract_symbols("function foo() {}", "ts") == []

    def test_syntax_error_returns_empty(self):
        """Unparseable source yields an empty list instead of raising."""
        # Callers rely on the no-raise behavior for fallback routing.
        broken = "def broken(:\n pass"
        assert self.extractor.extract_symbols(broken, "py") == []

    def test_top_level_function(self):
        """A single module-level ``def`` is reported with its line range."""
        found = self.extractor.extract_symbols("def my_func():\n return 42\n", "py")

        assert len(found) == 1
        only = found[0]
        assert (only.name, only.kind) == ("my_func", "function")
        assert (only.start_line, only.end_line) == (1, 2)

    def test_async_function(self):
        """``async def`` is reported as an ordinary function."""
        found = self.extractor.extract_symbols("async def fetch():\n return 1\n", "py")

        assert len(found) == 1
        assert found[0].name == "fetch"
        assert found[0].kind == "function"

    def test_top_level_class(self):
        """A class and each of its methods are all reported."""
        source = textwrap.dedent('''
            class MyClass:
                """docstring"""

                def method_a(self):
                    return 1

                async def method_b(self):
                    return 2
        ''').strip()
        found = self.extractor.extract_symbols(source, "py")

        reported = [item.name for item in found]
        for expected in ("MyClass", "method_a", "method_b"):
            assert expected in reported

        klass = next(item for item in found if item.name == "MyClass")
        assert klass.kind == "class"
        assert klass.start_line == 1
        # The class span must reach at least into the last method.
        assert klass.end_line >= 7

    def test_methods_classified_as_methods(self):
        """Defs inside a class are ``method``; module-level defs are ``function``."""
        source = textwrap.dedent('''
            class Foo:
                def bar(self):
                    pass

            def top_level():
                pass
        ''').strip()
        found = self.extractor.extract_symbols(source, "py")

        kind_of = {item.name: item.kind for item in found}
        assert kind_of["bar"] == "method"
        assert kind_of["top_level"] == "function"

    def test_decorated_function(self):
        """A decorated function's span starts on the ``def`` line, not the decorator."""
        source = textwrap.dedent('''
            @staticmethod
            def helper():
                return "hi"
        ''').strip()
        found = self.extractor.extract_symbols(source, "py")

        # The extractor reports the `def` line as the start; the decorator
        # line above it is deliberately left out of the span.
        helpers = [item for item in found if item.name == "helper"]
        assert helpers
        assert helpers[0].start_line == 2  # the `def` line

    def test_nested_function(self):
        """Both an outer function and the function nested in it are reported."""
        source = textwrap.dedent('''
            def outer():
                def inner():
                    return 1
                return inner()
        ''').strip()
        found = self.extractor.extract_symbols(source, "py")

        reported = {item.name for item in found}
        assert "outer" in reported
        assert "inner" in reported

    def test_empty_file(self):
        """Empty source has no symbols."""
        assert self.extractor.extract_symbols("", "py") == []

    def test_no_symbols_in_docstring_only_file(self):
        """A module consisting of just a docstring has no symbols."""
        docstring_only = '"""just a docstring"""\n'
        assert self.extractor.extract_symbols(docstring_only, "py") == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RegexSymbolExtractor — TS/JS/Go/Rust/Java
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegexSymbolExtractor:
    """Behavior of the regex-based extractor on TS/JS/Go/Rust/Java source text."""

    extractor = RegexSymbolExtractor()

    @staticmethod
    def _has_symbol(spans, name, kind=None):
        """Return True when *spans* holds a symbol called *name* (and of *kind*, if given)."""
        return any(
            span.name == name and (kind is None or span.kind == kind)
            for span in spans
        )

    def test_unsupported_language_returns_empty(self):
        """Language keys this extractor does not handle produce no symbols."""
        assert self.extractor.extract_symbols("def foo(): pass", "py") == []
        assert self.extractor.extract_symbols("function foo() {}", "rb") == []

    def test_typescript_function_declaration(self):
        """An exported TS function declaration is found as a function."""
        source = textwrap.dedent('''
            export function renderComponent(props: Props): JSX.Element {
                return <div/>;
            }
        ''').strip()
        found = self.extractor.extract_symbols(source, "ts")
        assert self._has_symbol(found, "renderComponent", "function")

    def test_typescript_async_function(self):
        """An ``async function`` declaration is found by name."""
        source = "async function fetchData() {\n return await fetch();\n}\n"
        found = self.extractor.extract_symbols(source, "ts")
        assert self._has_symbol(found, "fetchData")

    def test_typescript_arrow_function_const(self):
        """A ``const`` bound to an arrow function is found by name."""
        source = "const handleClick = (e: Event) => {\n console.log(e);\n};\n"
        found = self.extractor.extract_symbols(source, "ts")
        assert self._has_symbol(found, "handleClick")

    def test_typescript_class(self):
        """An exported abstract TS class is found as a class."""
        source = textwrap.dedent('''
            export abstract class BaseService {
                abstract run(): void;
            }
        ''').strip()
        found = self.extractor.extract_symbols(source, "ts")
        assert self._has_symbol(found, "BaseService", "class")

    def test_javascript_function(self):
        """A plain JS function declaration is found by name."""
        source = "function foo() {\n return 1;\n}\n"
        found = self.extractor.extract_symbols(source, "js")
        assert self._has_symbol(found, "foo")

    def test_javascript_arrow_const(self):
        """A one-line JS arrow-function const is found by name."""
        found = self.extractor.extract_symbols("const bar = () => 42;\n", "js")
        assert self._has_symbol(found, "bar")

    def test_go_function(self):
        """Go free functions and receiver methods are both found."""
        source = textwrap.dedent('''
            package main

            func HandleRequest(w http.ResponseWriter, r *http.Request) {
                w.WriteHeader(200)
            }

            func (s *Server) Start() {
                // method
            }
        ''').strip()
        found = self.extractor.extract_symbols(source, "go")

        reported = {span.name for span in found}
        assert "HandleRequest" in reported
        assert "Start" in reported  # method receiver pattern

    def test_go_struct(self):
        """A Go ``type ... struct`` is found with kind ``struct``."""
        source = "type Server struct {\n Addr string\n}\n"
        found = self.extractor.extract_symbols(source, "go")
        assert self._has_symbol(found, "Server", "struct")

    def test_rust_function(self):
        """Rust ``pub fn`` and ``async fn`` items are both found."""
        source = textwrap.dedent('''
            pub fn process(input: &str) -> Result<usize, Error> {
                Ok(input.len())
            }

            async fn fetch() -> Bytes {
                unimplemented!()
            }
        ''').strip()
        found = self.extractor.extract_symbols(source, "rs")

        reported = {span.name for span in found}
        assert "process" in reported
        assert "fetch" in reported

    def test_rust_struct(self):
        """A Rust ``pub struct`` is found with kind ``struct``."""
        source = "pub struct Config {\n pub path: String,\n}\n"
        found = self.extractor.extract_symbols(source, "rs")
        assert self._has_symbol(found, "Config", "struct")

    def test_rust_impl(self):
        """A Rust ``impl`` block is found under the type name with kind ``impl``."""
        source = "impl Config {\n pub fn new() -> Self { Self { path: String::new() } }\n}\n"
        found = self.extractor.extract_symbols(source, "rs")
        assert self._has_symbol(found, "Config", "impl")

    def test_java_class(self):
        """A public Java class is found as a class."""
        source = textwrap.dedent('''
            package com.example;

            public class UserService {
                public User findById(long id) {
                    return null;
                }
            }
        ''').strip()
        found = self.extractor.extract_symbols(source, "java")
        assert self._has_symbol(found, "UserService", "class")

    def test_java_method(self):
        """A Java method signature is found with kind ``function``."""
        source = "public User findById(long id) {\n return null;\n}\n"
        found = self.extractor.extract_symbols(source, "java")
        assert self._has_symbol(found, "findById", "function")

    def test_end_line_extends_to_next_symbol(self):
        """A symbol's span ends on the line just before the next symbol begins."""
        source = textwrap.dedent('''
            function first() {
                return 1;
            }

            function second() {
                return 2;
            }
        ''').strip()
        found = self.extractor.extract_symbols(source, "js")

        earlier, later = sorted(found, key=lambda span: span.start_line)[:2]
        assert earlier.name == "first"
        assert later.name == "second"
        assert earlier.end_line == later.start_line - 1

    def test_last_symbol_end_line_is_eof(self):
        """The final symbol's span runs to the last line of the source."""
        source = "function only() {\n return 1;\n}\n"
        found = self.extractor.extract_symbols(source, "js")

        assert len(found) == 1
        assert found[0].end_line == len(source.splitlines())
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# get_extractor + integration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetExtractor:
    """Selection of the extractor implementation for a language key."""

    def test_python_returns_ast_extractor(self):
        """``py`` is served by the AST-based extractor."""
        chosen = get_extractor("py")
        assert chosen is not None
        assert isinstance(chosen, AstSymbolExtractor)

    def test_typescript_returns_regex_extractor(self):
        """``ts`` is served by the regex-based extractor."""
        chosen = get_extractor("ts")
        assert chosen is not None
        assert isinstance(chosen, RegexSymbolExtractor)

    def test_unsupported_returns_none(self):
        """Unknown or empty language keys have no extractor."""
        for language in ("md", "", "unknown"):
            assert get_extractor(language) is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractSymbolsFromFile:
    """End-to-end extraction from a path: language detection plus symbol list."""

    def test_python_file(self, tmp_path):
        """A ``.py`` file is detected as Python and its function is found."""
        module = tmp_path / "module.py"
        module.write_text("def hello():\n return 'world'\n", encoding="utf-8")

        found, language = extract_symbols_from_file(module)

        assert language == "py"
        assert any(span.name == "hello" for span in found)

    def test_unsupported_extension(self, tmp_path):
        """An unsupported extension gives an empty language and no symbols."""
        notes = tmp_path / "notes.md"
        notes.write_text("# Hello\n", encoding="utf-8")

        found, language = extract_symbols_from_file(notes)

        assert language == ""
        assert found == []

    def test_missing_file_returns_empty(self, tmp_path):
        """A nonexistent path gives no symbols but still reports the language."""
        ghost = tmp_path / "nonexistent.py"

        found, language = extract_symbols_from_file(ghost)

        # The language comes from the extension alone, so it is reported
        # even though the file could not be read.
        assert language == "py"
        assert found == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SymbolSpan dataclass
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSymbolSpan:
    """Value semantics of the ``SymbolSpan`` record: immutability, equality, hashing."""

    def test_frozen_dataclass(self):
        """Assigning to a field of a constructed span is rejected."""
        span = SymbolSpan(name="foo", kind="function", start_line=1, end_line=3)
        assert span.name == "foo"
        # A frozen dataclass raises dataclasses.FrozenInstanceError, which
        # subclasses AttributeError. Expecting that (rather than a blanket
        # Exception) keeps unrelated failures from satisfying the test.
        with pytest.raises(AttributeError):
            span.name = "bar"  # type: ignore[misc] — frozen
        # The rejected assignment must leave the original value in place.
        assert span.name == "foo"

    def test_equality(self):
        """Spans built from equal fields compare equal and hash alike."""
        left = SymbolSpan("foo", "function", 1, 3)
        right = SymbolSpan("foo", "function", 1, 3)
        assert left == right
        assert hash(left) == hash(right)
|
||||||
Loading…
Reference in New Issue