fischer-agentkit/src/agentkit/tools/symbol_extractor.py

279 lines
9.9 KiB
Python

"""Symbol extraction — locate code symbols (functions/classes/structs) by name.
KTD1 (Wave 3 plan): Python `ast` (stdlib) for .py files; language-aware regex
for TS/JS/Go/Rust/Java. Avoids tree-sitter native dependency. The
`SymbolExtractor` protocol is the upgrade seam — a future TreeSitterSymbolExtractor
can replace RegexSymbolExtractor behind the same interface.
ponytail: regex extractor covers ~80% case (top-level function/class/struct
declarations). Ceiling: misses nested signatures inside JSX/TSX generics,
multi-line decorator chains, and macro-generated defs. Upgrade path = tree-sitter.
"""
from __future__ import annotations
import ast
import logging
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Protocol, runtime_checkable
# Module-level logger; this module reports parse/read failures at DEBUG level.
logger = logging.getLogger(__name__)
@dataclass(frozen=True, slots=True)
class SymbolSpan:
"""A located symbol — name, kind, and 1-based inclusive line range."""
name: str
kind: str # "function" | "class" | "method" | "struct" | "impl"
start_line: int # 1-based, inclusive
end_line: int # 1-based, inclusive
@runtime_checkable
class SymbolExtractor(Protocol):
    """Structural interface implemented by every symbol extractor.

    Decorated with ``runtime_checkable`` so that ``isinstance`` /
    ``issubclass`` checks can be made against it.
    """

    def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
        """Find every symbol in ``content`` and return the resulting spans.

        ``language`` is the file extension with the leading dot removed
        (for example ``"py"`` or ``"ts"``).

        Extraction failures must not surface as exceptions: on a parse error
        an implementation returns ``[]`` and the caller chooses the fallback
        (reading the full file).
        """
        ...
# ---------------------------------------------------------------------------
# Python — stdlib ast
# ---------------------------------------------------------------------------
class AstSymbolExtractor:
    """Python symbol extractor using the stdlib `ast` module.

    Reports every FunctionDef / AsyncFunctionDef / ClassDef found in the
    module, at any nesting depth, in `ast.walk` order. A function is labelled
    "method" only when it is a direct child of a class body; every other
    function (top-level, nested inside another function, or defined under an
    `if`/`try`) is labelled "function". Line ranges are produced by
    `_span_from_node`.
    """

    def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
        """Return spans for all functions and classes in Python `content`.

        Returns [] when `language` is not "py" or when `content` cannot be
        parsed. Parse failures are logged at DEBUG level and never raised.
        """
        if language != "py":
            return []
        try:
            tree = ast.parse(content)
        except (SyntaxError, ValueError, RecursionError) as e:
            # SyntaxError: invalid source. ValueError: some interpreter
            # versions reject NUL bytes this way. RecursionError: source
            # nested too deeply for the parser.
            logger.debug("ast.parse failed: %s", e)
            return []
        lines = content.splitlines()
        function_types = (ast.FunctionDef, ast.AsyncFunctionDef)

        # `ast.walk` does not expose parentage, so first record which function
        # nodes sit directly in a class body. Node identity is stable for as
        # long as `tree` is alive, so id() is a safe key here.
        method_ids: set[int] = set()
        for node in ast.walk(tree):
            if not isinstance(node, ast.ClassDef):
                continue
            for child in node.body:
                if isinstance(child, function_types):
                    method_ids.add(id(child))

        spans: list[SymbolSpan] = []
        for node in ast.walk(tree):
            if isinstance(node, function_types):
                kind = "method" if id(node) in method_ids else "function"
                spans.append(_span_from_node(node, kind, lines))
            elif isinstance(node, ast.ClassDef):
                spans.append(_span_from_node(node, "class", lines))
        return spans
def _is_method(node: ast.AST) -> bool:
    """Heuristically decide whether a function node should be labelled a method.

    `ast.walk` yields nodes without their parents, so the decision is based on
    the node's column instead: a node whose `col_offset` is greater than 0
    (i.e. it does not start at the left margin) is treated as a method. A node
    with no `col_offset` attribute counts as column 0.

    ponytail: indentation only approximates "defined in a class body" — a
    function nested inside another function is indented too and is therefore
    also reported as a method. Upgrade path = ast.NodeVisitor with parent
    tracking.
    """
    column = getattr(node, "col_offset", 0)
    return column > 0
def _span_from_node(
    node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef,
    kind: str,
    lines: list[str],
) -> SymbolSpan:
    """Build the SymbolSpan for a function or class node.

    The span starts at `node.lineno` (1-based), i.e. the line of the
    def/class keyword, so the range matches what the user sees there; any
    decorator lines above it are not part of the span. It ends at
    `node.end_lineno`, falling back to the start line when that attribute is
    None (old Pythons).

    `lines` is the file content split into lines; it is only used to cap the
    end line at the file length (defensive — ast should not exceed it).
    """
    first_line = node.lineno
    last_line = min(node.end_lineno or first_line, len(lines))
    return SymbolSpan(name=node.name, kind=kind, start_line=first_line, end_line=last_line)
# ---------------------------------------------------------------------------
# Regex extractor — TS/JS/Go/Rust/Java
# ---------------------------------------------------------------------------
# Maps a language key to an ordered list of (kind, compiled pattern) pairs.
# Each pattern matches a declaration and captures the symbol name in group 1.
# Patterns use re.MULTILINE so ^ matches line starts. Note that a leading
# `\s*` can also consume newlines, so a match may begin on a blank line above
# the declaration it describes.
_REGEX_PATTERNS: dict[str, list[tuple[str, re.Pattern[str]]]] = {
    "ts": [
        (
            "function",
            # `function name(` with optional `export` / `async` prefixes.
            re.compile(r"^\s*(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE),
        ),
        # `class Name` with optional `export` / `abstract` prefixes.
        ("class", re.compile(r"^\s*(?:export\s+)?(?:abstract\s+)?class\s+(\w+)\b", re.MULTILINE)),
        (
            "function",
            # `const|let|var name = ... =>` (arrow function), optional `export`.
            re.compile(
                r"^\s*(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>",
                re.MULTILINE,
            ),
        ),
    ],
    "js": [
        # Same shapes as the "ts" patterns, without the `export` / `abstract` prefixes.
        ("function", re.compile(r"^\s*(?:async\s+)?function\s+(\w+)\s*\(", re.MULTILINE)),
        ("class", re.compile(r"^\s*class\s+(\w+)\b", re.MULTILINE)),
        (
            "function",
            re.compile(
                r"^\s*(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s*)?\(?.*?\)?\s*=>", re.MULTILINE
            ),
        ),
    ],
    "go": [
        # No leading `\s*`: only declarations starting at column 0 match.
        # `func name(` with an optional parenthesised receiver before the name.
        ("function", re.compile(r"^func\s+(?:\([^)]*\)\s+)?(\w+)\s*\(", re.MULTILINE)),
        # `type Name struct`.
        ("struct", re.compile(r"^type\s+(\w+)\s+struct\b", re.MULTILINE)),
    ],
    "rs": [
        # `fn name(` with optional `pub` / `async` prefixes.
        ("function", re.compile(r"^\s*(?:pub\s+)?(?:async\s+)?fn\s+(\w+)\s*\(", re.MULTILINE)),
        # `struct Name` with an optional `pub` prefix.
        ("struct", re.compile(r"^\s*(?:pub\s+)?struct\s+(\w+)\b", re.MULTILINE)),
        # `impl ... Name {` at column 0; captures the identifier directly before `{`.
        ("impl", re.compile(r"^impl\b.*?\s+(\w+)\s*\{", re.MULTILINE)),
    ],
    "java": [
        (
            "function",
            # Method: optional visibility / `static` / `final`, a return type
            # (optionally with generic arguments), the name, a parameter list,
            # an optional `throws` clause, then the opening brace.
            re.compile(
                r"^\s*(?:public|private|protected)?\s*(?:static\s+)?(?:final\s+)?(?:\w+(?:<[^>]*>)?)\s+(\w+)\s*\([^)]*\)\s*(?:throws\s+\w+(?:\s*,\s*\w+)*)?\s*\{",
                re.MULTILINE,
            ),
        ),
        (
            "class",
            # `class Name` with optional `public` and `abstract` / `final` prefixes.
            re.compile(r"^\s*(?:public\s+)?(?:abstract\s+|final\s+)?class\s+(\w+)\b", re.MULTILINE),
        ),
    ],
}
class RegexSymbolExtractor:
    """Language-aware regex symbol extractor for TS/JS/Go/Rust/Java.

    Each symbol's start_line is the line its declaration begins on. Its
    end_line is the line before the next detected symbol starts, or the last
    line of the file for the final symbol. ponytail: this is an
    approximation — true block-end requires language-aware brace matching.
    Ceiling: a span may include code that follows the symbol's real end, and a
    nested declaration cuts the enclosing symbol's span short. Upgrade path =
    tree-sitter.
    """

    def extract_symbols(self, content: str, language: str) -> list[SymbolSpan]:
        """Return the symbols matched by the patterns registered for `language`.

        Returns [] when `language` has no patterns or nothing matches. Results
        are ordered by start line; when several patterns hit the same
        (line, name), the first pattern's kind wins.
        """
        patterns = _REGEX_PATTERNS.get(language)
        if not patterns:
            return []
        total_lines = len(content.splitlines())

        # Collect (line_no, name, kind) hits, dropping duplicates where
        # overlapping patterns report the same name on the same line.
        seen: set[tuple[int, str]] = set()
        hits: list[tuple[int, str, str]] = []
        for kind, pattern in patterns:
            for m in pattern.finditer(content):
                # Patterns may start with `^\s*`, and `\s*` also consumes
                # newlines, so the match can begin on blank lines above the
                # declaration. Skip that leading whitespace so the line number
                # points at the declaration itself.
                matched = m.group(0)
                leading_ws = len(matched) - len(matched.lstrip())
                decl_offset = m.start() + leading_ws
                # 1-based line number of the declaration (no slice copy).
                line_no = content.count("\n", 0, decl_offset) + 1
                name = m.group(1)
                if (line_no, name) in seen:
                    continue
                seen.add((line_no, name))
                hits.append((line_no, name, kind))
        if not hits:
            return []

        hits.sort(key=lambda hit: hit[0])
        # Each symbol ends on the line before the next one starts; the last
        # symbol runs to the end of the file.
        next_starts = [hit[0] for hit in hits[1:]]
        next_starts.append(total_lines + 1)
        return [
            SymbolSpan(
                name=name,
                kind=kind,
                start_line=start_line,
                # Never let a span end before it starts (two hits on one line).
                end_line=max(start_line, next_start - 1),
            )
            for (start_line, name, kind), next_start in zip(hits, next_starts)
        ]
# ---------------------------------------------------------------------------
# Dispatch by file extension
# ---------------------------------------------------------------------------
# Lower-case file extension (with leading dot) -> language key. Several
# extensions share one key: TypeScript variants map to "ts", JavaScript
# variants map to "js".
_EXTENSION_LANGUAGE: dict[str, str] = {
    ".py": "py",
    **dict.fromkeys((".ts", ".tsx"), "ts"),
    **dict.fromkeys((".js", ".jsx", ".mjs", ".cjs"), "js"),
    ".go": "go",
    ".rs": "rs",
    ".java": "java",
}
# Shared extractor instances returned by get_extractor(); neither class
# defines any instance state, so one instance of each is reused.
_DEFAULT_EXTRACTOR = AstSymbolExtractor()  # returned for "py"
_REGEX_EXTRACTOR = RegexSymbolExtractor()  # returned for keys of _REGEX_PATTERNS
def language_for_extension(ext: str) -> str:
    """Map a file extension to its language key.

    The extension may be passed with or without its leading dot and in any
    letter case. Unsupported extensions map to "".
    """
    dotted = ext if ext.startswith(".") else "." + ext
    return _EXTENSION_LANGUAGE.get(dotted.lower(), "")
def get_extractor(language: str) -> SymbolExtractor | None:
    """Pick the extractor that handles `language`.

    "py" is served by the ast-based extractor, any language with an entry in
    `_REGEX_PATTERNS` by the regex extractor; anything else yields None.
    """
    if language == "py":
        return _DEFAULT_EXTRACTOR
    return _REGEX_EXTRACTOR if language in _REGEX_PATTERNS else None
def extract_symbols_from_file(path: Path) -> tuple[list[SymbolSpan], str]:
    """Read a file and return (symbols, language).

    Returns ([], "") if the extension is unsupported, and ([], language) if
    no extractor exists for the language, the file cannot be read, or
    extraction fails.
    Never raises — callers use this for fallback routing.
    """
    # language_for_extension() lower-cases the extension itself.
    language = language_for_extension(path.suffix)
    if not language:
        return [], ""
    extractor = get_extractor(language)
    if extractor is None:
        return [], language
    try:
        # errors="replace" keeps undecodable bytes from raising.
        content = path.read_text(encoding="utf-8", errors="replace")
    except (OSError, ValueError) as e:
        # ValueError covers paths the OS layer rejects outright
        # (e.g. an embedded NUL byte).
        logger.debug("read failed for %s: %s", path, e)
        return [], language
    try:
        symbols = extractor.extract_symbols(content, language)
    except Exception as e:  # boundary: honour the documented never-raises contract
        logger.debug("symbol extraction failed for %s: %s", path, e)
        return [], language
    return symbols, language