"""Symbol extraction — locate code symbols (functions/classes/structs) by name. KTD1 (Wave 3 plan): Python `ast` (stdlib) for .py files; language-aware regex for TS/JS/Go/Rust/Java. Avoids tree-sitter native dependency. The `SymbolExtractor` protocol is the upgrade seam — a future TreeSitterSymbolExtractor can replace RegexSymbolExtractor behind the same interface. ponytail: regex extractor covers ~80% case (top-level function/class/struct declarations). Ceiling: misses nested signatures inside JSX/TSX generics, multi-line decorator chains, and macro-generated defs. Upgrade path = tree-sitter. """ from __future__ import annotations import ast import logging import re from dataclasses import dataclass from pathlib import Path from typing import Protocol, runtime_checkable logger = logging.getLogger(__name__) @dataclass(frozen=True, slots=True) class SymbolSpan: """A located symbol — name, kind, and 1-based inclusive line range.""" name: str kind: str # "function" | "class" | "method" | "struct" | "impl" start_line: int # 1-based, inclusive end_line: int # 1-based, inclusive @runtime_checkable class SymbolExtractor(Protocol): """Protocol for symbol extractors — runtime_checkable for isinstance/issubclass.""" def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]: """Return all symbols found in `content`. `language` is the file extension without leading dot (e.g. "py", "ts"). Implementations must never raise on extraction failure — return [] on parse errors and let the caller decide the fallback (full-file read). """ ... # --------------------------------------------------------------------------- # Python — stdlib ast # --------------------------------------------------------------------------- class AstSymbolExtractor: """Python symbol extractor using the stdlib `ast` module. Captures top-level FunctionDef/AsyncFunctionDef/ClassDef and methods/nested functions inside classes. The end_line is the last line of the node's source segment (decorator-inclusive). """ def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]: if language != "py": return [] try: tree = ast.parse(content) except SyntaxError as e: logger.debug("ast.parse failed: %s", e) return [] lines = content.splitlines() spans: list[SymbolSpan] = [] for node in ast.walk(tree): if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): kind = "method" if _is_method(node) else "function" spans.append(_span_from_node(node, kind, lines)) elif isinstance(node, ast.ClassDef): spans.append(_span_from_node(node, "class", lines)) return spans def _is_method(node: ast.AST) -> bool: """A FunctionDef is a method if its parent is a ClassDef. `ast.walk` doesn't expose parentage, so we approximate by checking the node's col_offset == 4 (indented inside a class body). ponytail: this misses methods in deeply nested classes — ceiling noted; upgrade path = ast.NodeVisitor with parent tracking. """ return getattr(node, "col_offset", 0) > 0 def _span_from_node( node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef, kind: str, lines: list[str], ) -> SymbolSpan: # ast line numbers are 1-based; start at decorator if present (lineno points # to the def/class keyword, decorators are above). Use node.lineno for start # so the returned range matches what the user sees at the def keyword. start = node.lineno # node.end_lineno is the last line of the node body (None on old Pythons). end = node.end_lineno or start # Clamp to actual file length (defensive — ast should not exceed, but # malformed files with no trailing newline can confuse end_lineno). if end > len(lines): end = len(lines) return SymbolSpan(name=node.name, kind=kind, start_line=start, end_line=end) # --------------------------------------------------------------------------- # Regex extractor — TS/JS/Go/Rust/Java # --------------------------------------------------------------------------- # Each pattern matches a declaration and captures the symbol name in group 1. # Patterns use re.MULTILINE so ^ matches line starts. _REGEX_PATTERNS: dict[str, list[tuple[str, re.Pattern[str]]]] = { "ts": [ ( "function", re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE), ), ("class", re.compile(r"^\s*(?:export\s+)?(?:abstract\s+)?class\s+(\w+)\b", re.MULTILINE)), ( "function", re.compile( r"^\s*(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>", re.MULTILINE, ), ), ], "js": [ ("function", re.compile(r"^\s*(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE)), ("class", re.compile(r"^\s*class\s+(\w+)\b", re.MULTILINE)), ( "function", re.compile( r"^\s*(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>", re.MULTILINE ), ), ], "go": [ ("function", re.compile(r"^func\s+(?:\([^)]*\)\s+)?(\w+)\s*\(", re.MULTILINE)), ("struct", re.compile(r"^type\s+(\w+)\s+struct\b", re.MULTILINE)), ], "rs": [ ("function", re.compile(r"^\s*(?:pub\s+)?(?:async\s+)?fn\s+(\w+)\s*\(", re.MULTILINE)), ("struct", re.compile(r"^\s*(?:pub\s+)?struct\s+(\w+)\b", re.MULTILINE)), ("impl", re.compile(r"^impl\b.*?\s+(\w+)\s*\{", re.MULTILINE)), ], "java": [ ( "function", re.compile( r"^\s*(?:public|private|protected)?\s*(?:static\s+)?(?:final\s+)?(?:\w+(?:<[^>]*>)?)\s+(\w+)\s*\([^)]*\)\s*(?:throws\s+\w+(?:\s*,\s*\w+)*)?\s*\{", re.MULTILINE, ), ), ( "class", re.compile(r"^\s*(?:public\s+)?(?:abstract\s+|final\s+)?class\s+(\w+)\b", re.MULTILINE), ), ], } class RegexSymbolExtractor: """Language-aware regex symbol extractor for TS/JS/Go/Rust/Java. Returns SymbolSpans whose end_line is approximated by the next blank line or next-symbol start (whichever comes first). ponytail: this is an approximation — true block-end requires language-aware brace matching. Ceiling: deeply nested blocks may over-extend the range. Upgrade path = tree-sitter. """ def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]: patterns = _REGEX_PATTERNS.get(language) if not patterns: return [] lines = content.splitlines() # Collect (line_no, name, kind) tuples first, then compute end_line # as the line before the next symbol starts (or EOF). raw_hits: list[tuple[int, str, str]] = [] for kind, pattern in patterns: for m in pattern.finditer(content): # Convert match offset to 1-based line number. line_no = content[: m.start()].count("\n") + 1 raw_hits.append((line_no, m.group(1), kind)) if not raw_hits: return [] # Deduplicate: same (line_no, name) may appear for overlapping patterns. seen: set[tuple[int, str]] = set() unique: list[tuple[int, str, str]] = [] for line_no, name, kind in raw_hits: key = (line_no, name) if key in seen: continue seen.add(key) unique.append((line_no, name, kind)) unique.sort(key=lambda x: x[0]) spans: list[SymbolSpan] = [] for i, (start_line, name, kind) in enumerate(unique): if i + 1 < len(unique): # End at line before next symbol starts, capped at file length. end_line = unique[i + 1][0] - 1 else: end_line = len(lines) if end_line < start_line: end_line = start_line spans.append(SymbolSpan(name=name, kind=kind, start_line=start_line, end_line=end_line)) return spans # --------------------------------------------------------------------------- # Dispatch by file extension # --------------------------------------------------------------------------- _EXTENSION_LANGUAGE: dict[str, str] = { ".py": "py", ".ts": "ts", ".tsx": "ts", ".js": "js", ".jsx": "js", ".mjs": "js", ".cjs": "js", ".go": "go", ".rs": "rs", ".java": "java", } _DEFAULT_EXTRACTOR = AstSymbolExtractor() _REGEX_EXTRACTOR = RegexSymbolExtractor() def language_for_extension(ext: str) -> str: """Return the language key for a file extension (with or without leading dot). Returns "" for unsupported extensions. """ if not ext.startswith("."): ext = "." + ext return _EXTENSION_LANGUAGE.get(ext.lower(), "") def get_extractor(language: str) -> SymbolExtractor | None: """Return the appropriate extractor for `language`, or None if unsupported.""" if language == "py": return _DEFAULT_EXTRACTOR if language in _REGEX_PATTERNS: return _REGEX_EXTRACTOR return None def extract_symbols_from_file(path: Path) -> tuple[list[SymbolSpan], str]: """Read a file and return (symbols, language). Returns ([], "") if the extension is unsupported or the file cannot be read. Never raises — callers use this for fallback routing. """ ext = path.suffix.lower() language = language_for_extension(ext) if not language: return [], "" try: content = path.read_text(encoding="utf-8", errors="replace") except OSError as e: logger.debug("read failed for %s: %s", path, e) return [], language extractor = get_extractor(language) if extractor is None: return [], language return extractor.extract_symbols(content, language), language