"""Unit tests for SymbolExtractor — AstSymbolExtractor + RegexSymbolExtractor. Covers R22 (file reading supports symbol/function granularity) and KTD1 (Python ast + language-aware regex, no tree-sitter dependency). """ from __future__ import annotations import textwrap import pytest from agentkit.tools.symbol_extractor import ( AstSymbolExtractor, RegexSymbolExtractor, SymbolSpan, extract_symbols_from_file, get_extractor, language_for_extension, ) # --------------------------------------------------------------------------- # language_for_extension # --------------------------------------------------------------------------- class TestLanguageForExtension: def test_python_extensions(self): assert language_for_extension("py") == "py" assert language_for_extension(".py") == "py" assert language_for_extension(".PY") == "py" # case-insensitive def test_typescript_javascript(self): assert language_for_extension(".ts") == "ts" assert language_for_extension(".tsx") == "ts" assert language_for_extension(".js") == "js" assert language_for_extension(".jsx") == "js" assert language_for_extension(".mjs") == "js" assert language_for_extension(".cjs") == "js" def test_go_rust_java(self): assert language_for_extension(".go") == "go" assert language_for_extension(".rs") == "rs" assert language_for_extension(".java") == "java" def test_unsupported_returns_empty(self): assert language_for_extension(".md") == "" assert language_for_extension(".txt") == "" assert language_for_extension("") == "" assert language_for_extension(".unknown") == "" # --------------------------------------------------------------------------- # AstSymbolExtractor — Python # --------------------------------------------------------------------------- class TestAstSymbolExtractor: extractor = AstSymbolExtractor() def test_unsupported_language_returns_empty(self): assert self.extractor.extract_symbols("function foo() {}", "ts") == [] def test_syntax_error_returns_empty(self): # Never raises — callers rely on this for fallback routing. result = self.extractor.extract_symbols("def broken(:\n pass", "py") assert result == [] def test_top_level_function(self): content = "def my_func():\n return 42\n" spans = self.extractor.extract_symbols(content, "py") assert len(spans) == 1 span = spans[0] assert span.name == "my_func" assert span.kind == "function" assert span.start_line == 1 assert span.end_line == 2 def test_async_function(self): content = "async def fetch():\n return 1\n" spans = self.extractor.extract_symbols(content, "py") assert len(spans) == 1 assert spans[0].name == "fetch" assert spans[0].kind == "function" def test_top_level_class(self): content = textwrap.dedent(''' class MyClass: """docstring""" def method_a(self): return 1 async def method_b(self): return 2 ''').strip() spans = self.extractor.extract_symbols(content, "py") names = [s.name for s in spans] assert "MyClass" in names assert "method_a" in names assert "method_b" in names cls = next(s for s in spans if s.name == "MyClass") assert cls.kind == "class" assert cls.start_line == 1 # Class body extends through the last method's end_lineno. assert cls.end_line >= 7 def test_methods_classified_as_methods(self): content = textwrap.dedent(''' class Foo: def bar(self): pass def top_level(): pass ''').strip() spans = self.extractor.extract_symbols(content, "py") by_name = {s.name: s for s in spans} assert by_name["bar"].kind == "method" assert by_name["top_level"].kind == "function" def test_decorated_function(self): content = textwrap.dedent(''' @staticmethod def helper(): return "hi" ''').strip() spans = self.extractor.extract_symbols(content, "py") # Note: extractor uses node.lineno (def line) — decorators above are # excluded by design (matches user-visible symbol start at `def`). assert any(s.name == "helper" for s in spans) span = next(s for s in spans if s.name == "helper") assert span.start_line == 2 # the `def` line def test_nested_function(self): content = textwrap.dedent(''' def outer(): def inner(): return 1 return inner() ''').strip() spans = self.extractor.extract_symbols(content, "py") names = {s.name for s in spans} assert "outer" in names assert "inner" in names def test_empty_file(self): assert self.extractor.extract_symbols("", "py") == [] def test_no_symbols_in_docstring_only_file(self): content = '"""just a docstring"""\n' assert self.extractor.extract_symbols(content, "py") == [] # --------------------------------------------------------------------------- # RegexSymbolExtractor — TS/JS/Go/Rust/Java # --------------------------------------------------------------------------- class TestRegexSymbolExtractor: extractor = RegexSymbolExtractor() def test_unsupported_language_returns_empty(self): assert self.extractor.extract_symbols("def foo(): pass", "py") == [] assert self.extractor.extract_symbols("function foo() {}", "rb") == [] def test_typescript_function_declaration(self): content = textwrap.dedent(''' export function renderComponent(props: Props): JSX.Element { return
; } ''').strip() spans = self.extractor.extract_symbols(content, "ts") assert any(s.name == "renderComponent" and s.kind == "function" for s in spans) def test_typescript_async_function(self): content = "async function fetchData() {\n return await fetch();\n}\n" spans = self.extractor.extract_symbols(content, "ts") assert any(s.name == "fetchData" for s in spans) def test_typescript_arrow_function_const(self): content = "const handleClick = (e: Event) => {\n console.log(e);\n};\n" spans = self.extractor.extract_symbols(content, "ts") assert any(s.name == "handleClick" for s in spans) def test_typescript_class(self): content = textwrap.dedent(''' export abstract class BaseService { abstract run(): void; } ''').strip() spans = self.extractor.extract_symbols(content, "ts") assert any(s.name == "BaseService" and s.kind == "class" for s in spans) def test_javascript_function(self): content = "function foo() {\n return 1;\n}\n" spans = self.extractor.extract_symbols(content, "js") assert any(s.name == "foo" for s in spans) def test_javascript_arrow_const(self): content = "const bar = () => 42;\n" spans = self.extractor.extract_symbols(content, "js") assert any(s.name == "bar" for s in spans) def test_go_function(self): content = textwrap.dedent(''' package main func HandleRequest(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) } func (s *Server) Start() { // method } ''').strip() spans = self.extractor.extract_symbols(content, "go") names = {s.name for s in spans} assert "HandleRequest" in names assert "Start" in names # method receiver pattern def test_go_struct(self): content = "type Server struct {\n Addr string\n}\n" spans = self.extractor.extract_symbols(content, "go") assert any(s.name == "Server" and s.kind == "struct" for s in spans) def test_rust_function(self): content = textwrap.dedent(''' pub fn process(input: &str) -> Result