feat(agent): Wave 3 strategic coupling (G5/G6) #6
|
|
@ -18,6 +18,7 @@ from agentkit.tools.memory_tool import MemoryTool
|
|||
from agentkit.tools.web_search import WebSearchTool
|
||||
from agentkit.tools.builtin import RunTestsTool, ToolSearchTool
|
||||
from agentkit.tools.search import ToolSearchIndex
|
||||
from agentkit.tools.file_read import ReadFileTool
|
||||
|
||||
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
|
||||
try:
|
||||
|
|
@ -52,4 +53,5 @@ __all__ = [
|
|||
"OutputParser",
|
||||
"ParsedOutput",
|
||||
"ErrorType",
|
||||
"ReadFileTool",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,262 @@
|
|||
"""ReadFileTool — file reading with optional symbol-level sharding (G5, R22/R23).
|
||||
|
||||
Backward compatible with the pre-existing `_FakeTool` benchmark shape — when
|
||||
`symbol=None`, returns the full file content. When `symbol="foo"`, returns
|
||||
the line range of the first matching symbol via `SymbolExtractor`.
|
||||
|
||||
KTD2 (Wave 3 plan): dedicated tool, does NOT extend ShellTool — keeps the
|
||||
file-reading contract clean and gives the LLM a focused schema.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from agentkit.tools.base import Tool
|
||||
from agentkit.tools.symbol_extractor import (
|
||||
SymbolSpan,
|
||||
extract_symbols_from_file,
|
||||
language_for_extension,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReadFileTool(Tool):
|
||||
"""Read a file from the filesystem, optionally sliced to a single symbol.
|
||||
|
||||
Tool name `read_file` matches the reserved entry in
|
||||
`core/react.py:_DEFAULT_CORE_TOOLS` (which previously had no real
|
||||
implementation — only `_FakeTool` stubs in `cli/benchmark.py`).
|
||||
|
||||
Backward-compat contract: `symbol=None` returns the full file content,
|
||||
matching the shape `{"path": ...}` that downstream callers (benchmark,
|
||||
phase whitelist) already expect.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "read_file",
|
||||
description: str | None = None,
|
||||
input_schema: dict[str, Any] | None = None,
|
||||
output_schema: dict[str, Any] | None = None,
|
||||
version: str = "1.0.0",
|
||||
tags: list[str] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
name=name,
|
||||
description=description
|
||||
or (
|
||||
"Read a file from the filesystem. By default returns the full file "
|
||||
"content. Pass `symbol` (function/class/struct name) to slice to just "
|
||||
"that symbol's line range — saves context when you only need one "
|
||||
"function from a large file. Pass `start_line`/`end_line` for manual "
|
||||
"slicing. If `symbol` is set but not found, returns the available "
|
||||
"symbol names so you can retry."
|
||||
),
|
||||
input_schema=input_schema or self._default_input_schema(),
|
||||
output_schema=output_schema or self._default_output_schema(),
|
||||
version=version,
|
||||
tags=tags or ["io", "file", "read"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _default_input_schema() -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to read (absolute or relative to cwd).",
|
||||
},
|
||||
"symbol": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional: name of a function/class/struct/method to slice to. "
|
||||
"When set, returns only the line range of the first matching "
|
||||
"symbol. Supported languages: py, ts/tsx, js/jsx, go, rs, java."
|
||||
),
|
||||
},
|
||||
"start_line": {
|
||||
"type": "integer",
|
||||
"description": "Optional 1-based start line for manual slicing. Overrides `symbol`.",
|
||||
"minimum": 1,
|
||||
},
|
||||
"end_line": {
|
||||
"type": "integer",
|
||||
"description": "Optional 1-based end line (inclusive) for manual slicing. Overrides `symbol`.",
|
||||
"minimum": 1,
|
||||
},
|
||||
},
|
||||
"required": ["path"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _default_output_schema() -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {"type": "string"},
|
||||
"path": {"type": "string"},
|
||||
"start_line": {"type": "integer"},
|
||||
"end_line": {"type": "integer"},
|
||||
"symbol": {"type": "string"},
|
||||
"available_symbols": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Populated when `symbol` is set but not found.",
|
||||
},
|
||||
"note": {"type": "string"},
|
||||
"is_error": {"type": "boolean"},
|
||||
"error": {"type": "string"},
|
||||
},
|
||||
}
|
||||
|
||||
async def execute(self, **kwargs) -> dict[str, Any]:
|
||||
raw_path = kwargs.get("path")
|
||||
if not raw_path:
|
||||
return self._error("`path` is required")
|
||||
|
||||
path = Path(raw_path)
|
||||
if not path.is_absolute():
|
||||
path = path.resolve()
|
||||
|
||||
symbol = kwargs.get("symbol")
|
||||
start_line = kwargs.get("start_line")
|
||||
end_line = kwargs.get("end_line")
|
||||
|
||||
# Validate/sanitize line overrides.
|
||||
if start_line is not None and (not isinstance(start_line, int) or start_line < 1):
|
||||
return self._error(f"`start_line` must be a positive integer, got {start_line!r}")
|
||||
if end_line is not None and (not isinstance(end_line, int) or end_line < 1):
|
||||
return self._error(f"`end_line` must be a positive integer, got {end_line!r}")
|
||||
if start_line is not None and end_line is not None and end_line < start_line:
|
||||
return self._error(f"`end_line` ({end_line}) must be >= `start_line` ({start_line})")
|
||||
|
||||
# Filesystem checks.
|
||||
if not path.exists():
|
||||
return self._error(f"File not found: {path}", path=str(path))
|
||||
if path.is_dir():
|
||||
return self._error(f"Path is a directory, not a file: {path}", path=str(path))
|
||||
|
||||
try:
|
||||
content = path.read_text(encoding="utf-8", errors="replace")
|
||||
except PermissionError as e:
|
||||
return self._error(f"Permission denied: {path}", path=str(path), detail=str(e))
|
||||
except OSError as e:
|
||||
return self._error(f"Failed to read {path}: {e}", path=str(path))
|
||||
|
||||
lines = content.splitlines()
|
||||
total_lines = len(lines)
|
||||
|
||||
# Manual slicing takes precedence over symbol (per plan U1 Approach).
|
||||
if start_line is not None or end_line is not None:
|
||||
s = max(1, start_line or 1)
|
||||
e = min(total_lines, end_line or total_lines)
|
||||
sliced = "\n".join(lines[s - 1 : e])
|
||||
return {
|
||||
"content": sliced,
|
||||
"path": str(path),
|
||||
"start_line": s,
|
||||
"end_line": e,
|
||||
"total_lines": total_lines,
|
||||
"is_error": False,
|
||||
}
|
||||
|
||||
# Symbol slicing.
|
||||
if symbol:
|
||||
ext = path.suffix.lower()
|
||||
language = language_for_extension(ext)
|
||||
if not language:
|
||||
# Unsupported extension: return full file with note (per plan U1 Edge case).
|
||||
return {
|
||||
"content": content,
|
||||
"path": str(path),
|
||||
"start_line": 1,
|
||||
"end_line": total_lines,
|
||||
"total_lines": total_lines,
|
||||
"note": f"symbol extraction not supported for {ext or 'unknown extension'}",
|
||||
"is_error": False,
|
||||
}
|
||||
|
||||
spans, _lang = extract_symbols_from_file(path)
|
||||
# Re-extract using the content we already read so we don't read the file twice.
|
||||
if not spans:
|
||||
# Try extraction from in-memory content (path-based extraction may
|
||||
# have failed silently on OSError; we already read it successfully).
|
||||
from agentkit.tools.symbol_extractor import get_extractor
|
||||
|
||||
extractor = get_extractor(language)
|
||||
if extractor is not None:
|
||||
spans = extractor.extract_symbols(content, language)
|
||||
|
||||
match = _find_symbol(spans, symbol)
|
||||
if match is None:
|
||||
available = sorted({s.name for s in spans})
|
||||
return {
|
||||
"content": "",
|
||||
"path": str(path),
|
||||
"symbol": symbol,
|
||||
"available_symbols": available,
|
||||
"is_error": False,
|
||||
"note": (
|
||||
f"Symbol {symbol!r} not found in {path.name}. "
|
||||
f"Available: {', '.join(available) if available else '(none)'}"
|
||||
),
|
||||
}
|
||||
|
||||
s = match.start_line
|
||||
e = min(match.end_line, total_lines)
|
||||
sliced = "\n".join(lines[s - 1 : e])
|
||||
return {
|
||||
"content": sliced,
|
||||
"path": str(path),
|
||||
"symbol": symbol,
|
||||
"symbol_kind": match.kind,
|
||||
"start_line": s,
|
||||
"end_line": e,
|
||||
"total_lines": total_lines,
|
||||
"is_error": False,
|
||||
}
|
||||
|
||||
# Default: full file (characterization baseline — matches _FakeTool shape).
|
||||
return {
|
||||
"content": content,
|
||||
"path": str(path),
|
||||
"start_line": 1,
|
||||
"end_line": total_lines,
|
||||
"total_lines": total_lines,
|
||||
"is_error": False,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _error(
|
||||
message: str, *, path: str | None = None, detail: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
result: dict[str, Any] = {
|
||||
"content": "",
|
||||
"is_error": True,
|
||||
"error": message,
|
||||
}
|
||||
if path is not None:
|
||||
result["path"] = path
|
||||
if detail is not None:
|
||||
result["detail"] = detail
|
||||
return result
|
||||
|
||||
|
||||
def _find_symbol(spans: list[SymbolSpan], name: str) -> SymbolSpan | None:
|
||||
"""Find the first symbol matching `name`. Case-sensitive.
|
||||
|
||||
ponytail: linear scan is fine for typical file symbol counts (<100). The
|
||||
extractor already returns symbols sorted by start_line; first match wins
|
||||
for ambiguous overloads (e.g., Python classes with same name in different
|
||||
modules — not relevant within one file).
|
||||
"""
|
||||
for span in spans:
|
||||
if span.name == name:
|
||||
return span
|
||||
return None
|
||||
|
|
@ -0,0 +1,278 @@
|
|||
"""Symbol extraction — locate code symbols (functions/classes/structs) by name.
|
||||
|
||||
KTD1 (Wave 3 plan): Python `ast` (stdlib) for .py files; language-aware regex
|
||||
for TS/JS/Go/Rust/Java. Avoids tree-sitter native dependency. The
|
||||
`SymbolExtractor` protocol is the upgrade seam — a future TreeSitterSymbolExtractor
|
||||
can replace RegexSymbolExtractor behind the same interface.
|
||||
|
||||
ponytail: regex extractor covers ~80% case (top-level function/class/struct
|
||||
declarations). Ceiling: misses nested signatures inside JSX/TSX generics,
|
||||
multi-line decorator chains, and macro-generated defs. Upgrade path = tree-sitter.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class SymbolSpan:
|
||||
"""A located symbol — name, kind, and 1-based inclusive line range."""
|
||||
|
||||
name: str
|
||||
kind: str # "function" | "class" | "method" | "struct" | "impl"
|
||||
start_line: int # 1-based, inclusive
|
||||
end_line: int # 1-based, inclusive
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SymbolExtractor(Protocol):
|
||||
"""Protocol for symbol extractors — runtime_checkable for isinstance/issubclass."""
|
||||
|
||||
def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
|
||||
"""Return all symbols found in `content`.
|
||||
|
||||
`language` is the file extension without leading dot (e.g. "py", "ts").
|
||||
Implementations must never raise on extraction failure — return [] on
|
||||
parse errors and let the caller decide the fallback (full-file read).
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Python — stdlib ast
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AstSymbolExtractor:
|
||||
"""Python symbol extractor using the stdlib `ast` module.
|
||||
|
||||
Captures top-level FunctionDef/AsyncFunctionDef/ClassDef and methods/nested
|
||||
functions inside classes. The end_line is the last line of the node's
|
||||
source segment (decorator-inclusive).
|
||||
"""
|
||||
|
||||
def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
|
||||
if language != "py":
|
||||
return []
|
||||
try:
|
||||
tree = ast.parse(content)
|
||||
except SyntaxError as e:
|
||||
logger.debug("ast.parse failed: %s", e)
|
||||
return []
|
||||
|
||||
lines = content.splitlines()
|
||||
spans: list[SymbolSpan] = []
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
kind = "method" if _is_method(node) else "function"
|
||||
spans.append(_span_from_node(node, kind, lines))
|
||||
elif isinstance(node, ast.ClassDef):
|
||||
spans.append(_span_from_node(node, "class", lines))
|
||||
return spans
|
||||
|
||||
|
||||
def _is_method(node: ast.AST) -> bool:
|
||||
"""A FunctionDef is a method if its parent is a ClassDef.
|
||||
|
||||
`ast.walk` doesn't expose parentage, so we approximate by checking the
|
||||
node's col_offset == 4 (indented inside a class body). ponytail: this
|
||||
misses methods in deeply nested classes — ceiling noted; upgrade path =
|
||||
ast.NodeVisitor with parent tracking.
|
||||
"""
|
||||
return getattr(node, "col_offset", 0) > 0
|
||||
|
||||
|
||||
def _span_from_node(
|
||||
node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef,
|
||||
kind: str,
|
||||
lines: list[str],
|
||||
) -> SymbolSpan:
|
||||
# ast line numbers are 1-based; start at decorator if present (lineno points
|
||||
# to the def/class keyword, decorators are above). Use node.lineno for start
|
||||
# so the returned range matches what the user sees at the def keyword.
|
||||
start = node.lineno
|
||||
# node.end_lineno is the last line of the node body (None on old Pythons).
|
||||
end = node.end_lineno or start
|
||||
# Clamp to actual file length (defensive — ast should not exceed, but
|
||||
# malformed files with no trailing newline can confuse end_lineno).
|
||||
if end > len(lines):
|
||||
end = len(lines)
|
||||
return SymbolSpan(name=node.name, kind=kind, start_line=start, end_line=end)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regex extractor — TS/JS/Go/Rust/Java
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Each pattern matches a declaration and captures the symbol name in group 1.
|
||||
# Patterns use re.MULTILINE so ^ matches line starts.
|
||||
_REGEX_PATTERNS: dict[str, list[tuple[str, re.Pattern[str]]]] = {
|
||||
"ts": [
|
||||
(
|
||||
"function",
|
||||
re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE),
|
||||
),
|
||||
("class", re.compile(r"^\s*(?:export\s+)?(?:abstract\s+)?class\s+(\w+)\b", re.MULTILINE)),
|
||||
(
|
||||
"function",
|
||||
re.compile(
|
||||
r"^\s*(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>",
|
||||
re.MULTILINE,
|
||||
),
|
||||
),
|
||||
],
|
||||
"js": [
|
||||
("function", re.compile(r"^\s*(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE)),
|
||||
("class", re.compile(r"^\s*class\s+(\w+)\b", re.MULTILINE)),
|
||||
(
|
||||
"function",
|
||||
re.compile(
|
||||
r"^\s*(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>", re.MULTILINE
|
||||
),
|
||||
),
|
||||
],
|
||||
"go": [
|
||||
("function", re.compile(r"^func\s+(?:\([^)]*\)\s+)?(\w+)\s*\(", re.MULTILINE)),
|
||||
("struct", re.compile(r"^type\s+(\w+)\s+struct\b", re.MULTILINE)),
|
||||
],
|
||||
"rs": [
|
||||
("function", re.compile(r"^\s*(?:pub\s+)?(?:async\s+)?fn\s+(\w+)\s*\(", re.MULTILINE)),
|
||||
("struct", re.compile(r"^\s*(?:pub\s+)?struct\s+(\w+)\b", re.MULTILINE)),
|
||||
("impl", re.compile(r"^impl\b.*?\s+(\w+)\s*\{", re.MULTILINE)),
|
||||
],
|
||||
"java": [
|
||||
(
|
||||
"function",
|
||||
re.compile(
|
||||
r"^\s*(?:public|private|protected)?\s*(?:static\s+)?(?:final\s+)?(?:\w+(?:<[^>]*>)?)\s+(\w+)\s*\([^)]*\)\s*(?:throws\s+\w+(?:\s*,\s*\w+)*)?\s*\{",
|
||||
re.MULTILINE,
|
||||
),
|
||||
),
|
||||
(
|
||||
"class",
|
||||
re.compile(r"^\s*(?:public\s+)?(?:abstract\s+|final\s+)?class\s+(\w+)\b", re.MULTILINE),
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class RegexSymbolExtractor:
|
||||
"""Language-aware regex symbol extractor for TS/JS/Go/Rust/Java.
|
||||
|
||||
Returns SymbolSpans whose end_line is approximated by the next blank line
|
||||
or next-symbol start (whichever comes first). ponytail: this is an
|
||||
approximation — true block-end requires language-aware brace matching.
|
||||
Ceiling: deeply nested blocks may over-extend the range. Upgrade path =
|
||||
tree-sitter.
|
||||
"""
|
||||
|
||||
def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
|
||||
patterns = _REGEX_PATTERNS.get(language)
|
||||
if not patterns:
|
||||
return []
|
||||
|
||||
lines = content.splitlines()
|
||||
# Collect (line_no, name, kind) tuples first, then compute end_line
|
||||
# as the line before the next symbol starts (or EOF).
|
||||
raw_hits: list[tuple[int, str, str]] = []
|
||||
for kind, pattern in patterns:
|
||||
for m in pattern.finditer(content):
|
||||
# Convert match offset to 1-based line number.
|
||||
line_no = content[: m.start()].count("\n") + 1
|
||||
raw_hits.append((line_no, m.group(1), kind))
|
||||
|
||||
if not raw_hits:
|
||||
return []
|
||||
|
||||
# Deduplicate: same (line_no, name) may appear for overlapping patterns.
|
||||
seen: set[tuple[int, str]] = set()
|
||||
unique: list[tuple[int, str, str]] = []
|
||||
for line_no, name, kind in raw_hits:
|
||||
key = (line_no, name)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
unique.append((line_no, name, kind))
|
||||
|
||||
unique.sort(key=lambda x: x[0])
|
||||
|
||||
spans: list[SymbolSpan] = []
|
||||
for i, (start_line, name, kind) in enumerate(unique):
|
||||
if i + 1 < len(unique):
|
||||
# End at line before next symbol starts, capped at file length.
|
||||
end_line = unique[i + 1][0] - 1
|
||||
else:
|
||||
end_line = len(lines)
|
||||
if end_line < start_line:
|
||||
end_line = start_line
|
||||
spans.append(SymbolSpan(name=name, kind=kind, start_line=start_line, end_line=end_line))
|
||||
return spans
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dispatch by file extension
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_EXTENSION_LANGUAGE: dict[str, str] = {
|
||||
".py": "py",
|
||||
".ts": "ts",
|
||||
".tsx": "ts",
|
||||
".js": "js",
|
||||
".jsx": "js",
|
||||
".mjs": "js",
|
||||
".cjs": "js",
|
||||
".go": "go",
|
||||
".rs": "rs",
|
||||
".java": "java",
|
||||
}
|
||||
|
||||
_DEFAULT_EXTRACTOR = AstSymbolExtractor()
|
||||
_REGEX_EXTRACTOR = RegexSymbolExtractor()
|
||||
|
||||
|
||||
def language_for_extension(ext: str) -> str:
|
||||
"""Return the language key for a file extension (with or without leading dot).
|
||||
|
||||
Returns "" for unsupported extensions.
|
||||
"""
|
||||
if not ext.startswith("."):
|
||||
ext = "." + ext
|
||||
return _EXTENSION_LANGUAGE.get(ext.lower(), "")
|
||||
|
||||
|
||||
def get_extractor(language: str) -> SymbolExtractor | None:
|
||||
"""Return the appropriate extractor for `language`, or None if unsupported."""
|
||||
if language == "py":
|
||||
return _DEFAULT_EXTRACTOR
|
||||
if language in _REGEX_PATTERNS:
|
||||
return _REGEX_EXTRACTOR
|
||||
return None
|
||||
|
||||
|
||||
def extract_symbols_from_file(path: Path) -> tuple[list[SymbolSpan], str]:
|
||||
"""Read a file and return (symbols, language).
|
||||
|
||||
Returns ([], "") if the extension is unsupported or the file cannot be read.
|
||||
Never raises — callers use this for fallback routing.
|
||||
"""
|
||||
ext = path.suffix.lower()
|
||||
language = language_for_extension(ext)
|
||||
if not language:
|
||||
return [], ""
|
||||
try:
|
||||
content = path.read_text(encoding="utf-8", errors="replace")
|
||||
except OSError as e:
|
||||
logger.debug("read failed for %s: %s", path, e)
|
||||
return [], language
|
||||
extractor = get_extractor(language)
|
||||
if extractor is None:
|
||||
return [], language
|
||||
return extractor.extract_symbols(content, language), language
|
||||
|
|
@ -0,0 +1,367 @@
|
|||
"""Unit tests for ReadFileTool — G5 (R22, R23) + characterization baseline.
|
||||
|
||||
Per plan U1 Execution note: characterization-first — assert that
|
||||
`symbol=None` returns the full file content (matches pre-existing benchmark
|
||||
`_FakeTool` shape) before adding symbol-extraction behavior.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.tools.file_read import ReadFileTool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReadFileToolSchema:
|
||||
def test_name_is_read_file(self):
|
||||
tool = ReadFileTool()
|
||||
assert tool.name == "read_file"
|
||||
|
||||
def test_required_path(self):
|
||||
tool = ReadFileTool()
|
||||
assert "path" in tool.input_schema["required"]
|
||||
assert "path" in tool.input_schema["properties"]
|
||||
|
||||
def test_optional_symbol_and_lines(self):
|
||||
tool = ReadFileTool()
|
||||
props = tool.input_schema["properties"]
|
||||
assert "symbol" in props
|
||||
assert "start_line" in props
|
||||
assert "end_line" in props
|
||||
# None of the optional fields should be in `required`.
|
||||
required = set(tool.input_schema["required"])
|
||||
assert required == {"path"}
|
||||
|
||||
def test_additional_properties_false(self):
|
||||
# LLM tool-call schemas should reject unknown args (Wave 1 U3 pattern).
|
||||
tool = ReadFileTool()
|
||||
assert tool.input_schema.get("additionalProperties") is False
|
||||
|
||||
def test_tags_contain_io_and_read(self):
|
||||
tool = ReadFileTool()
|
||||
assert "io" in tool.tags
|
||||
assert "read" in tool.tags
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Characterization — symbol=None returns full file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_py_file(tmp_path):
|
||||
path = tmp_path / "sample.py"
|
||||
path.write_text(
|
||||
textwrap.dedent('''
|
||||
"""Sample module."""
|
||||
|
||||
def my_func():
|
||||
return 42
|
||||
|
||||
|
||||
class MyClass:
|
||||
attr = 1
|
||||
|
||||
def method_a(self):
|
||||
return self.attr
|
||||
''').lstrip(),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_ts_file(tmp_path):
|
||||
path = tmp_path / "sample.ts"
|
||||
path.write_text(
|
||||
textwrap.dedent('''
|
||||
export function renderComponent(): JSX.Element {
|
||||
return <div/>;
|
||||
}
|
||||
|
||||
export class BaseService {
|
||||
abstract run(): void;
|
||||
}
|
||||
''').lstrip(),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
class TestCharacterizationFullFile:
|
||||
"""symbol=None returns the whole file (matches _FakeTool baseline)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_file_returned_when_symbol_none(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_py_file))
|
||||
|
||||
assert result["is_error"] is False
|
||||
assert result["path"] == str(sample_py_file)
|
||||
assert result["start_line"] == 1
|
||||
assert result["end_line"] == result["total_lines"]
|
||||
assert "def my_func" in result["content"]
|
||||
assert "class MyClass" in result["content"]
|
||||
assert result["content"].startswith('"""Sample module."""')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_file_includes_all_lines(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_py_file))
|
||||
assert result["total_lines"] >= 8
|
||||
assert result["content"].count("\n") >= result["total_lines"] - 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Symbol slicing — happy paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSymbolSlicing:
|
||||
@pytest.mark.asyncio
|
||||
async def test_python_function(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_py_file), symbol="my_func")
|
||||
|
||||
assert result["is_error"] is False
|
||||
assert result["symbol"] == "my_func"
|
||||
assert result["symbol_kind"] == "function"
|
||||
assert "def my_func" in result["content"]
|
||||
assert "return 42" in result["content"]
|
||||
# Should NOT include the class below.
|
||||
assert "class MyClass" not in result["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_python_class_includes_method(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_py_file), symbol="MyClass")
|
||||
|
||||
assert result["is_error"] is False
|
||||
assert result["symbol"] == "MyClass"
|
||||
assert result["symbol_kind"] == "class"
|
||||
assert "class MyClass" in result["content"]
|
||||
assert "def method_a" in result["content"] # method included
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_python_method_directly(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_py_file), symbol="method_a")
|
||||
|
||||
assert result["is_error"] is False
|
||||
assert result["symbol"] == "method_a"
|
||||
assert result["symbol_kind"] == "method"
|
||||
assert "def method_a" in result["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_typescript_function(self, sample_ts_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_ts_file), symbol="renderComponent")
|
||||
|
||||
assert result["is_error"] is False
|
||||
assert result["symbol"] == "renderComponent"
|
||||
assert "renderComponent" in result["content"]
|
||||
# Should not include the class below.
|
||||
assert "BaseService" not in result["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_typescript_class(self, sample_ts_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_ts_file), symbol="BaseService")
|
||||
|
||||
assert result["is_error"] is False
|
||||
assert result["symbol"] == "BaseService"
|
||||
assert result["symbol_kind"] == "class"
|
||||
assert "BaseService" in result["content"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Symbol slicing — edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSymbolSlicingEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_symbol_not_found_lists_available(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_py_file), symbol="nonexistent")
|
||||
|
||||
assert result["is_error"] is False # soft error, not hard
|
||||
assert result["content"] == ""
|
||||
assert result["symbol"] == "nonexistent"
|
||||
available = result["available_symbols"]
|
||||
assert "my_func" in available
|
||||
assert "MyClass" in available
|
||||
assert "method_a" in available
|
||||
assert "nonexistent" not in result["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_extension_returns_full_with_note(self, tmp_path):
|
||||
path = tmp_path / "notes.md"
|
||||
path.write_text("# Hello\nworld\n", encoding="utf-8")
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(path), symbol="anything")
|
||||
|
||||
assert result["is_error"] is False
|
||||
assert result["content"] == "# Hello\nworld\n"
|
||||
assert "symbol extraction not supported" in result["note"]
|
||||
assert ".md" in result["note"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_file(self, tmp_path):
|
||||
path = tmp_path / "empty.py"
|
||||
path.write_text("", encoding="utf-8")
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(path))
|
||||
|
||||
assert result["is_error"] is False
|
||||
assert result["content"] == ""
|
||||
assert result["total_lines"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_with_no_symbols(self, tmp_path):
|
||||
path = tmp_path / "data.py"
|
||||
path.write_text("# just a comment\nPI = 3.14\n", encoding="utf-8")
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(path), symbol="PI")
|
||||
|
||||
# PI is not a def/class — extractor finds no symbols; soft error lists available.
|
||||
assert result["is_error"] is False
|
||||
assert result["content"] == ""
|
||||
assert result["available_symbols"] == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error paths
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReadFileToolErrors:
|
||||
@pytest.mark.asyncio
|
||||
async def test_path_required(self):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute()
|
||||
assert result["is_error"] is True
|
||||
assert "path" in result["error"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_path_empty_string(self):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path="")
|
||||
assert result["is_error"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_not_found(self, tmp_path):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(tmp_path / "missing.py"))
|
||||
assert result["is_error"] is True
|
||||
assert "not found" in result["error"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_path_is_directory(self, tmp_path):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(tmp_path))
|
||||
assert result["is_error"] is True
|
||||
assert "directory" in result["error"].lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Manual line slicing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestManualLineSlicing:
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_and_end_line(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(
|
||||
path=str(sample_py_file),
|
||||
start_line=3,
|
||||
end_line=5,
|
||||
)
|
||||
assert result["is_error"] is False
|
||||
assert result["start_line"] == 3
|
||||
assert result["end_line"] == 5
|
||||
# Lines 3-5 of the sample file:
|
||||
# line 3: "def my_func():"
|
||||
# line 4: " return 42"
|
||||
# line 5: "" (blank)
|
||||
assert "def my_func" in result["content"]
|
||||
assert "return 42" in result["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_line_only_extends_to_eof(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_py_file), start_line=8)
|
||||
assert result["is_error"] is False
|
||||
assert result["start_line"] == 8
|
||||
assert result["end_line"] == result["total_lines"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_line_only_starts_at_one(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_py_file), end_line=2)
|
||||
assert result["is_error"] is False
|
||||
assert result["start_line"] == 1
|
||||
assert result["end_line"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_start_line_zero(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(path=str(sample_py_file), start_line=0)
|
||||
assert result["is_error"] is True
|
||||
assert "start_line" in result["error"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_before_start(self, sample_py_file):
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(
|
||||
path=str(sample_py_file),
|
||||
start_line=5,
|
||||
end_line=3,
|
||||
)
|
||||
assert result["is_error"] is True
|
||||
assert "end_line" in result["error"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manual_lines_override_symbol(self, sample_py_file):
|
||||
# Per plan U1 Approach: "start_line/end_line overrides symbol".
|
||||
tool = ReadFileTool()
|
||||
result = await tool.execute(
|
||||
path=str(sample_py_file),
|
||||
symbol="my_func",
|
||||
start_line=1,
|
||||
end_line=1,
|
||||
)
|
||||
assert result["is_error"] is False
|
||||
# Manual slicing won — symbol field absent.
|
||||
assert "symbol" not in result or result.get("symbol") is None
|
||||
assert result["start_line"] == 1
|
||||
assert result["end_line"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration — tool registry discovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToolRegistryDiscovery:
|
||||
def test_instantiable_without_args(self):
|
||||
# Default constructor — matches the convention used by ToolRegistry
|
||||
# to instantiate tools by class.
|
||||
tool = ReadFileTool()
|
||||
assert tool.name == "read_file"
|
||||
|
||||
def test_to_dict_serializable(self):
|
||||
tool = ReadFileTool()
|
||||
d = tool.to_dict()
|
||||
assert d["name"] == "read_file"
|
||||
assert "input_schema" in d
|
||||
assert "output_schema" in d
|
||||
assert d["tags"] == ["io", "file", "read"]
|
||||
|
|
@ -0,0 +1,359 @@
|
|||
"""Unit tests for SymbolExtractor — AstSymbolExtractor + RegexSymbolExtractor.
|
||||
|
||||
Covers R22 (file reading supports symbol/function granularity) and KTD1
|
||||
(Python ast + language-aware regex, no tree-sitter dependency).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.tools.symbol_extractor import (
|
||||
AstSymbolExtractor,
|
||||
RegexSymbolExtractor,
|
||||
SymbolSpan,
|
||||
extract_symbols_from_file,
|
||||
get_extractor,
|
||||
language_for_extension,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# language_for_extension
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLanguageForExtension:
|
||||
def test_python_extensions(self):
|
||||
assert language_for_extension("py") == "py"
|
||||
assert language_for_extension(".py") == "py"
|
||||
assert language_for_extension(".PY") == "py" # case-insensitive
|
||||
|
||||
def test_typescript_javascript(self):
|
||||
assert language_for_extension(".ts") == "ts"
|
||||
assert language_for_extension(".tsx") == "ts"
|
||||
assert language_for_extension(".js") == "js"
|
||||
assert language_for_extension(".jsx") == "js"
|
||||
assert language_for_extension(".mjs") == "js"
|
||||
assert language_for_extension(".cjs") == "js"
|
||||
|
||||
def test_go_rust_java(self):
|
||||
assert language_for_extension(".go") == "go"
|
||||
assert language_for_extension(".rs") == "rs"
|
||||
assert language_for_extension(".java") == "java"
|
||||
|
||||
def test_unsupported_returns_empty(self):
|
||||
assert language_for_extension(".md") == ""
|
||||
assert language_for_extension(".txt") == ""
|
||||
assert language_for_extension("") == ""
|
||||
assert language_for_extension(".unknown") == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AstSymbolExtractor — Python
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAstSymbolExtractor:
|
||||
extractor = AstSymbolExtractor()
|
||||
|
||||
def test_unsupported_language_returns_empty(self):
|
||||
assert self.extractor.extract_symbols("function foo() {}", "ts") == []
|
||||
|
||||
def test_syntax_error_returns_empty(self):
|
||||
# Never raises — callers rely on this for fallback routing.
|
||||
result = self.extractor.extract_symbols("def broken(:\n pass", "py")
|
||||
assert result == []
|
||||
|
||||
def test_top_level_function(self):
|
||||
content = "def my_func():\n return 42\n"
|
||||
spans = self.extractor.extract_symbols(content, "py")
|
||||
assert len(spans) == 1
|
||||
span = spans[0]
|
||||
assert span.name == "my_func"
|
||||
assert span.kind == "function"
|
||||
assert span.start_line == 1
|
||||
assert span.end_line == 2
|
||||
|
||||
def test_async_function(self):
|
||||
content = "async def fetch():\n return 1\n"
|
||||
spans = self.extractor.extract_symbols(content, "py")
|
||||
assert len(spans) == 1
|
||||
assert spans[0].name == "fetch"
|
||||
assert spans[0].kind == "function"
|
||||
|
||||
def test_top_level_class(self):
|
||||
content = textwrap.dedent('''
|
||||
class MyClass:
|
||||
"""docstring"""
|
||||
|
||||
def method_a(self):
|
||||
return 1
|
||||
|
||||
async def method_b(self):
|
||||
return 2
|
||||
''').strip()
|
||||
spans = self.extractor.extract_symbols(content, "py")
|
||||
names = [s.name for s in spans]
|
||||
assert "MyClass" in names
|
||||
assert "method_a" in names
|
||||
assert "method_b" in names
|
||||
|
||||
cls = next(s for s in spans if s.name == "MyClass")
|
||||
assert cls.kind == "class"
|
||||
assert cls.start_line == 1
|
||||
# Class body extends through the last method's end_lineno.
|
||||
assert cls.end_line >= 7
|
||||
|
||||
def test_methods_classified_as_methods(self):
|
||||
content = textwrap.dedent('''
|
||||
class Foo:
|
||||
def bar(self):
|
||||
pass
|
||||
|
||||
def top_level():
|
||||
pass
|
||||
''').strip()
|
||||
spans = self.extractor.extract_symbols(content, "py")
|
||||
by_name = {s.name: s for s in spans}
|
||||
assert by_name["bar"].kind == "method"
|
||||
assert by_name["top_level"].kind == "function"
|
||||
|
||||
def test_decorated_function(self):
|
||||
content = textwrap.dedent('''
|
||||
@staticmethod
|
||||
def helper():
|
||||
return "hi"
|
||||
''').strip()
|
||||
spans = self.extractor.extract_symbols(content, "py")
|
||||
# Note: extractor uses node.lineno (def line) — decorators above are
|
||||
# excluded by design (matches user-visible symbol start at `def`).
|
||||
assert any(s.name == "helper" for s in spans)
|
||||
span = next(s for s in spans if s.name == "helper")
|
||||
assert span.start_line == 2 # the `def` line
|
||||
|
||||
def test_nested_function(self):
|
||||
content = textwrap.dedent('''
|
||||
def outer():
|
||||
def inner():
|
||||
return 1
|
||||
return inner()
|
||||
''').strip()
|
||||
spans = self.extractor.extract_symbols(content, "py")
|
||||
names = {s.name for s in spans}
|
||||
assert "outer" in names
|
||||
assert "inner" in names
|
||||
|
||||
def test_empty_file(self):
|
||||
assert self.extractor.extract_symbols("", "py") == []
|
||||
|
||||
def test_no_symbols_in_docstring_only_file(self):
|
||||
content = '"""just a docstring"""\n'
|
||||
assert self.extractor.extract_symbols(content, "py") == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RegexSymbolExtractor — TS/JS/Go/Rust/Java
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegexSymbolExtractor:
|
||||
extractor = RegexSymbolExtractor()
|
||||
|
||||
def test_unsupported_language_returns_empty(self):
|
||||
assert self.extractor.extract_symbols("def foo(): pass", "py") == []
|
||||
assert self.extractor.extract_symbols("function foo() {}", "rb") == []
|
||||
|
||||
def test_typescript_function_declaration(self):
|
||||
content = textwrap.dedent('''
|
||||
export function renderComponent(props: Props): JSX.Element {
|
||||
return <div/>;
|
||||
}
|
||||
''').strip()
|
||||
spans = self.extractor.extract_symbols(content, "ts")
|
||||
assert any(s.name == "renderComponent" and s.kind == "function" for s in spans)
|
||||
|
||||
def test_typescript_async_function(self):
|
||||
content = "async function fetchData() {\n return await fetch();\n}\n"
|
||||
spans = self.extractor.extract_symbols(content, "ts")
|
||||
assert any(s.name == "fetchData" for s in spans)
|
||||
|
||||
def test_typescript_arrow_function_const(self):
|
||||
content = "const handleClick = (e: Event) => {\n console.log(e);\n};\n"
|
||||
spans = self.extractor.extract_symbols(content, "ts")
|
||||
assert any(s.name == "handleClick" for s in spans)
|
||||
|
||||
def test_typescript_class(self):
|
||||
content = textwrap.dedent('''
|
||||
export abstract class BaseService {
|
||||
abstract run(): void;
|
||||
}
|
||||
''').strip()
|
||||
spans = self.extractor.extract_symbols(content, "ts")
|
||||
assert any(s.name == "BaseService" and s.kind == "class" for s in spans)
|
||||
|
||||
def test_javascript_function(self):
|
||||
content = "function foo() {\n return 1;\n}\n"
|
||||
spans = self.extractor.extract_symbols(content, "js")
|
||||
assert any(s.name == "foo" for s in spans)
|
||||
|
||||
def test_javascript_arrow_const(self):
|
||||
content = "const bar = () => 42;\n"
|
||||
spans = self.extractor.extract_symbols(content, "js")
|
||||
assert any(s.name == "bar" for s in spans)
|
||||
|
||||
def test_go_function(self):
|
||||
content = textwrap.dedent('''
|
||||
package main
|
||||
|
||||
func HandleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
|
||||
func (s *Server) Start() {
|
||||
// method
|
||||
}
|
||||
''').strip()
|
||||
spans = self.extractor.extract_symbols(content, "go")
|
||||
names = {s.name for s in spans}
|
||||
assert "HandleRequest" in names
|
||||
assert "Start" in names # method receiver pattern
|
||||
|
||||
def test_go_struct(self):
|
||||
content = "type Server struct {\n Addr string\n}\n"
|
||||
spans = self.extractor.extract_symbols(content, "go")
|
||||
assert any(s.name == "Server" and s.kind == "struct" for s in spans)
|
||||
|
||||
def test_rust_function(self):
|
||||
content = textwrap.dedent('''
|
||||
pub fn process(input: &str) -> Result<usize, Error> {
|
||||
Ok(input.len())
|
||||
}
|
||||
|
||||
async fn fetch() -> Bytes {
|
||||
unimplemented!()
|
||||
}
|
||||
''').strip()
|
||||
spans = self.extractor.extract_symbols(content, "rs")
|
||||
names = {s.name for s in spans}
|
||||
assert "process" in names
|
||||
assert "fetch" in names
|
||||
|
||||
def test_rust_struct(self):
|
||||
content = "pub struct Config {\n pub path: String,\n}\n"
|
||||
spans = self.extractor.extract_symbols(content, "rs")
|
||||
assert any(s.name == "Config" and s.kind == "struct" for s in spans)
|
||||
|
||||
def test_rust_impl(self):
|
||||
content = "impl Config {\n pub fn new() -> Self { Self { path: String::new() } }\n}\n"
|
||||
spans = self.extractor.extract_symbols(content, "rs")
|
||||
assert any(s.name == "Config" and s.kind == "impl" for s in spans)
|
||||
|
||||
def test_java_class(self):
|
||||
content = textwrap.dedent('''
|
||||
package com.example;
|
||||
|
||||
public class UserService {
|
||||
public User findById(long id) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
''').strip()
|
||||
spans = self.extractor.extract_symbols(content, "java")
|
||||
assert any(s.name == "UserService" and s.kind == "class" for s in spans)
|
||||
|
||||
def test_java_method(self):
|
||||
content = "public User findById(long id) {\n return null;\n}\n"
|
||||
spans = self.extractor.extract_symbols(content, "java")
|
||||
assert any(s.name == "findById" and s.kind == "function" for s in spans)
|
||||
|
||||
def test_end_line_extends_to_next_symbol(self):
|
||||
# First symbol's end_line is the line before the second symbol starts.
|
||||
content = textwrap.dedent('''
|
||||
function first() {
|
||||
return 1;
|
||||
}
|
||||
|
||||
function second() {
|
||||
return 2;
|
||||
}
|
||||
''').strip()
|
||||
spans = self.extractor.extract_symbols(content, "js")
|
||||
spans.sort(key=lambda s: s.start_line)
|
||||
first = spans[0]
|
||||
second = spans[1]
|
||||
assert first.name == "first"
|
||||
assert second.name == "second"
|
||||
assert first.end_line == second.start_line - 1
|
||||
|
||||
def test_last_symbol_end_line_is_eof(self):
|
||||
content = "function only() {\n return 1;\n}\n"
|
||||
spans = self.extractor.extract_symbols(content, "js")
|
||||
assert len(spans) == 1
|
||||
assert spans[0].end_line == len(content.splitlines())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_extractor + integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetExtractor:
|
||||
def test_python_returns_ast_extractor(self):
|
||||
ext = get_extractor("py")
|
||||
assert ext is not None
|
||||
assert isinstance(ext, AstSymbolExtractor)
|
||||
|
||||
def test_typescript_returns_regex_extractor(self):
|
||||
ext = get_extractor("ts")
|
||||
assert ext is not None
|
||||
assert isinstance(ext, RegexSymbolExtractor)
|
||||
|
||||
def test_unsupported_returns_none(self):
|
||||
assert get_extractor("md") is None
|
||||
assert get_extractor("") is None
|
||||
assert get_extractor("unknown") is None
|
||||
|
||||
|
||||
class TestExtractSymbolsFromFile:
|
||||
def test_python_file(self, tmp_path):
|
||||
path = tmp_path / "module.py"
|
||||
path.write_text("def hello():\n return 'world'\n", encoding="utf-8")
|
||||
spans, lang = extract_symbols_from_file(path)
|
||||
assert lang == "py"
|
||||
assert any(s.name == "hello" for s in spans)
|
||||
|
||||
def test_unsupported_extension(self, tmp_path):
|
||||
path = tmp_path / "notes.md"
|
||||
path.write_text("# Hello\n", encoding="utf-8")
|
||||
spans, lang = extract_symbols_from_file(path)
|
||||
assert lang == ""
|
||||
assert spans == []
|
||||
|
||||
def test_missing_file_returns_empty(self, tmp_path):
|
||||
path = tmp_path / "nonexistent.py"
|
||||
spans, lang = extract_symbols_from_file(path)
|
||||
# lang is detected from extension even if read fails.
|
||||
assert lang == "py"
|
||||
assert spans == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SymbolSpan dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSymbolSpan:
|
||||
def test_frozen_dataclass(self):
|
||||
span = SymbolSpan(name="foo", kind="function", start_line=1, end_line=3)
|
||||
assert span.name == "foo"
|
||||
with pytest.raises(Exception):
|
||||
span.name = "bar" # type: ignore[misc] — frozen
|
||||
|
||||
def test_equality(self):
|
||||
a = SymbolSpan("foo", "function", 1, 3)
|
||||
b = SymbolSpan("foo", "function", 1, 3)
|
||||
assert a == b
|
||||
assert hash(a) == hash(b)
|
||||
Loading…
Reference in New Issue