279 lines
9.9 KiB
Python
279 lines
9.9 KiB
Python
"""Symbol extraction — locate code symbols (functions/classes/structs) by name.
|
|
|
|
KTD1 (Wave 3 plan): Python `ast` (stdlib) for .py files; language-aware regex
|
|
for TS/JS/Go/Rust/Java. Avoids tree-sitter native dependency. The
|
|
`SymbolExtractor` protocol is the upgrade seam — a future TreeSitterSymbolExtractor
|
|
can replace RegexSymbolExtractor behind the same interface.
|
|
|
|
ponytail: regex extractor covers ~80% case (top-level function/class/struct
|
|
declarations). Ceiling: misses nested signatures inside JSX/TSX generics,
|
|
multi-line decorator chains, and macro-generated defs. Upgrade path = tree-sitter.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import ast
|
|
import logging
|
|
import re
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Protocol, runtime_checkable
|
|
|
|
# Module-scoped logger; this module only emits DEBUG records (parse and read
# failures), so it stays silent under default logging configuration.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class SymbolSpan:
|
|
"""A located symbol — name, kind, and 1-based inclusive line range."""
|
|
|
|
name: str
|
|
kind: str # "function" | "class" | "method" | "struct" | "impl"
|
|
start_line: int # 1-based, inclusive
|
|
end_line: int # 1-based, inclusive
|
|
|
|
|
|
@runtime_checkable
class SymbolExtractor(Protocol):
    """Structural interface every symbol extractor satisfies.

    Decorated with ``runtime_checkable`` so ``isinstance``/``issubclass``
    checks against the protocol work at runtime.
    """

    def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
        """Return every symbol found in `content`.

        Args:
            content: Full source text to scan.
            language: File extension without the leading dot (e.g. "py", "ts").

        Returns:
            The located symbols. Implementations must never raise when
            extraction fails: they return [] on parse errors and leave the
            fallback decision (full-file read) to the caller.
        """
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Python — stdlib ast
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class AstSymbolExtractor:
    """Python symbol extractor using the stdlib `ast` module.

    Reports every FunctionDef/AsyncFunctionDef/ClassDef found anywhere in the
    module (`ast.walk` visits the whole tree, so nested definitions are
    included). A function is tagged "method" when its nearest enclosing
    def/class scope is a class, and "function" otherwise — so a helper nested
    inside another function is a "function", not a "method". Line ranges are
    produced by `_span_from_node`.
    """

    def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
        """Return a span for each function, method and class in `content`.

        Args:
            content: Python source text.
            language: Language key; anything other than "py" yields [].

        Returns:
            Spans in the order `ast.walk` yields the nodes, or [] when the
            source cannot be parsed. Never raises on unparsable input.
        """
        if language != "py":
            return []
        try:
            tree = ast.parse(content)
        except (SyntaxError, ValueError, RecursionError) as e:
            # Besides SyntaxError, parsing raises ValueError for source with
            # null bytes and RecursionError for pathologically deep nesting.
            # The extractor contract is "return [] rather than raise".
            logger.debug("ast.parse failed: %s", e)
            return []

        lines = content.splitlines()
        method_ids = self._method_node_ids(tree)
        spans: list[SymbolSpan] = []
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                kind = "method" if id(node) in method_ids else "function"
                spans.append(_span_from_node(node, kind, lines))
            elif isinstance(node, ast.ClassDef):
                spans.append(_span_from_node(node, "class", lines))
        return spans

    @staticmethod
    def _method_node_ids(tree: ast.AST) -> set[int]:
        """Return the id() of every function node whose nearest scope is a class.

        Walks the tree iteratively while carrying an "inside a class body"
        flag: entering a ClassDef sets it, entering a function clears it, and
        every other node (e.g. an `if` in a class body) passes it through.
        """
        func_types = (ast.FunctionDef, ast.AsyncFunctionDef)
        method_ids: set[int] = set()
        pending: list[tuple[ast.AST, bool]] = [(tree, False)]
        while pending:
            node, in_class = pending.pop()
            if isinstance(node, func_types):
                if in_class:
                    method_ids.add(id(node))
                in_class = False  # anything defined inside is function-scoped
            elif isinstance(node, ast.ClassDef):
                in_class = True
            pending.extend((child, in_class) for child in ast.iter_child_nodes(node))
        return method_ids
|
|
|
|
|
|
def _is_method(node: ast.AST) -> bool:
    """Guess whether a function node is a method from its indentation alone.

    `ast.walk` doesn't expose parentage, so this returns True whenever the
    node's `col_offset` is greater than zero; a node without that attribute
    is treated as sitting at column 0. ponytail: this is an approximation —
    any indented def counts, including a function nested inside another
    function. Ceiling noted; upgrade path = ast.NodeVisitor with parent
    tracking.
    """
    indent = getattr(node, "col_offset", 0)
    return indent > 0
|
|
|
|
|
|
def _span_from_node(
    node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef,
    kind: str,
    lines: list[str],
) -> SymbolSpan:
    """Build the SymbolSpan for a def/class node.

    Args:
        node: The function or class definition node.
        kind: Label stored on the span as-is.
        lines: The file's lines; only the count is used, to cap the end line.

    Returns:
        A span that starts at `node.lineno` — the line of the def/class
        keyword, so decorator lines above it are not part of the range — and
        ends at `node.end_lineno`.
    """
    first = node.lineno
    # end_lineno can be None (old Pythons): fall back to a one-line span.
    # The min() is a defensive cap so the range never runs past the file.
    last = min(node.end_lineno or first, len(lines))
    return SymbolSpan(name=node.name, kind=kind, start_line=first, end_line=last)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Regex extractor — TS/JS/Go/Rust/Java
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Each pattern matches a declaration and captures the symbol name in group 1.
# Patterns use re.MULTILINE so ^ matches line starts.
#
# Leading indentation is matched with [ \t]* rather than \s*: \s also matches
# newlines, which would let a match begin on a blank line above the
# declaration and make the match offset point at the wrong line.
#
# TS and JS both accept an optional `export` / `export default` prefix, since
# ES-module files (.mjs, .jsx, .ts, .tsx) routinely export their declarations.
_REGEX_PATTERNS: dict[str, list[tuple[str, re.Pattern[str]]]] = {
    "ts": [
        (
            "function",
            re.compile(
                r"^[ \t]*(?:export\s+(?:default\s+)?)?(?:async\s+)?function\s+(\w+)\s*\(",
                re.MULTILINE,
            ),
        ),
        (
            "class",
            re.compile(
                r"^[ \t]*(?:export\s+(?:default\s+)?)?(?:abstract\s+)?class\s+(\w+)\b",
                re.MULTILINE,
            ),
        ),
        (
            # Arrow function bound to a const/let/var name.
            "function",
            re.compile(
                r"^[ \t]*(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>",
                re.MULTILINE,
            ),
        ),
    ],
    "js": [
        (
            "function",
            re.compile(
                r"^[ \t]*(?:export\s+(?:default\s+)?)?(?:async\s+)?function\s+(\w+)\s*\(",
                re.MULTILINE,
            ),
        ),
        (
            "class",
            re.compile(
                r"^[ \t]*(?:export\s+(?:default\s+)?)?class\s+(\w+)\b",
                re.MULTILINE,
            ),
        ),
        (
            # Arrow function bound to a const/let/var name.
            "function",
            re.compile(
                r"^[ \t]*(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>",
                re.MULTILINE,
            ),
        ),
    ],
    "go": [
        # Optional receiver in parentheses before the function name.
        ("function", re.compile(r"^func\s+(?:\([^)]*\)\s+)?(\w+)\s*\(", re.MULTILINE)),
        ("struct", re.compile(r"^type\s+(\w+)\s+struct\b", re.MULTILINE)),
    ],
    "rs": [
        (
            "function",
            re.compile(r"^[ \t]*(?:pub\s+)?(?:async\s+)?fn\s+(\w+)\s*\(", re.MULTILINE),
        ),
        ("struct", re.compile(r"^[ \t]*(?:pub\s+)?struct\s+(\w+)\b", re.MULTILINE)),
        ("impl", re.compile(r"^impl\b.*?\s+(\w+)\s*\{", re.MULTILINE)),
    ],
    "java": [
        (
            # [visibility] [static] [final] <return type> <name>(<params>) [throws ...] {
            "function",
            re.compile(
                r"^[ \t]*(?:(?:public|private|protected)\s+)?(?:static\s+)?(?:final\s+)?(?:\w+(?:<[^>]*>)?)\s+(\w+)\s*\([^)]*\)\s*(?:throws\s+\w+(?:\s*,\s*\w+)*)?\s*\{",
                re.MULTILINE,
            ),
        ),
        (
            "class",
            re.compile(
                r"^[ \t]*(?:public\s+)?(?:abstract\s+|final\s+)?class\s+(\w+)\b",
                re.MULTILINE,
            ),
        ),
    ],
}
|
|
|
|
|
|
class RegexSymbolExtractor:
    """Language-aware regex symbol extractor for TS/JS/Go/Rust/Java.

    Every pattern registered for the language in `_REGEX_PATTERNS` is run over
    the whole file; each match yields one symbol named by capture group 1.

    A symbol's start_line is the line holding the first non-whitespace
    character of its match. Its end_line is the line just before the next
    detected symbol starts, or the last line of the file for the final
    symbol. ponytail: this is an approximation — true block-end requires
    language-aware brace matching. Ceiling: anything between two detected
    symbols (blank lines, comments, undetected code) is attributed to the
    earlier one. Upgrade path = tree-sitter.
    """

    def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
        """Return the symbols the language's patterns find in `content`.

        Args:
            content: Full source text to scan.
            language: Key into `_REGEX_PATTERNS` (e.g. "ts", "go").

        Returns:
            Spans sorted by start line; [] when the language has no patterns
            or nothing matches.
        """
        patterns = _REGEX_PATTERNS.get(language)
        if not patterns:
            return []

        # Collect (line_no, name, kind) hits, dropping repeats of the same
        # (line_no, name) produced by overlapping patterns; the first pattern
        # to hit decides the kind.
        seen: set[tuple[int, str]] = set()
        hits: list[tuple[int, str, str]] = []
        for kind, pattern in patterns:
            for match in pattern.finditer(content):
                # A match may begin with whitespace (indentation, or even
                # blank lines when a pattern opens with \s*). Anchor the line
                # number on the first non-whitespace character so start_line
                # is the declaration's own line.
                text = match.group(0)
                decl_offset = match.start() + len(text) - len(text.lstrip())
                line_no = content.count("\n", 0, decl_offset) + 1  # 1-based
                name = match.group(1)
                if (line_no, name) in seen:
                    continue
                seen.add((line_no, name))
                hits.append((line_no, name, kind))

        if not hits:
            return []

        # Stable sort: hits on the same line keep their discovery order.
        hits.sort(key=lambda hit: hit[0])

        last_line = len(content.splitlines())
        spans: list[SymbolSpan] = []
        for index, (start_line, name, kind) in enumerate(hits):
            if index + 1 < len(hits):
                end_line = hits[index + 1][0] - 1
            else:
                end_line = last_line
            # Two symbols on one line would give end < start; keep the span
            # at least one line long.
            end_line = max(end_line, start_line)
            spans.append(
                SymbolSpan(name=name, kind=kind, start_line=start_line, end_line=end_line)
            )
        return spans
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Dispatch by file extension
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# File extension (lower-case, with leading dot) -> language key. Declared
# grouped by language so that adding an extension means touching one tuple.
_EXTENSION_LANGUAGE: dict[str, str] = {
    extension: language
    for language, extensions in (
        ("py", (".py",)),
        ("ts", (".ts", ".tsx")),
        ("js", (".js", ".jsx", ".mjs", ".cjs")),
        ("go", (".go",)),
        ("rs", (".rs",)),
        ("java", (".java",)),
    )
    for extension in extensions
}
|
|
|
|
# Module-level instances returned by get_extractor(). Neither class defines
# instance state, so a single shared instance of each is enough.
_DEFAULT_EXTRACTOR: AstSymbolExtractor = AstSymbolExtractor()
_REGEX_EXTRACTOR: RegexSymbolExtractor = RegexSymbolExtractor()
|
|
|
|
|
|
def language_for_extension(ext: str) -> str:
    """Map a file extension to its language key.

    Args:
        ext: Extension with or without the leading dot, in any letter case.

    Returns:
        The language key, or "" when the extension is unsupported.
    """
    dotted = ext if ext.startswith(".") else "." + ext
    return _EXTENSION_LANGUAGE.get(dotted.lower(), "")
|
|
|
|
|
|
def get_extractor(language: str) -> SymbolExtractor | None:
    """Pick the extractor that handles `language`.

    Returns:
        The ast-based extractor for "py", the regex extractor for any
        language that has an entry in `_REGEX_PATTERNS`, otherwise None.
    """
    if language == "py":
        return _DEFAULT_EXTRACTOR
    return _REGEX_EXTRACTOR if language in _REGEX_PATTERNS else None
|
|
|
|
|
|
def extract_symbols_from_file(path: Path) -> tuple[list[SymbolSpan], str]:
    """Read a file and return (symbols, language).

    Args:
        path: File to scan; the language is chosen from its suffix.

    Returns:
        ([], "") when the extension is unsupported. ([], language) when the
        file cannot be read, no extractor exists for the language, or
        extraction fails. Otherwise (symbols, language).

    Never raises — callers use this for fallback routing.
    """
    language = language_for_extension(path.suffix.lower())
    if not language:
        return [], ""

    try:
        # errors="replace" keeps undecodable bytes from raising.
        content = path.read_text(encoding="utf-8", errors="replace")
    except (OSError, ValueError) as e:
        # ValueError covers paths the OS layer rejects outright (e.g. an
        # embedded null byte), which is not reported as an OSError.
        logger.debug("read failed for %s: %s", path, e)
        return [], language

    extractor = get_extractor(language)
    if extractor is None:
        return [], language

    try:
        symbols = extractor.extract_symbols(content, language)
    except Exception as e:
        # Extractors are expected not to raise, but this function is the
        # boundary that promises it: degrade to "no symbols" so the caller
        # falls back instead of crashing.
        logger.debug("symbol extraction failed for %s: %s", path, e)
        return [], language
    return symbols, language
|