"""Unit tests for SymbolExtractor — AstSymbolExtractor + RegexSymbolExtractor. Covers R22 (file reading supports symbol/function granularity) and KTD1 (Python ast + language-aware regex, no tree-sitter dependency). """ from __future__ import annotations import textwrap import pytest from agentkit.tools.symbol_extractor import ( AstSymbolExtractor, RegexSymbolExtractor, SymbolSpan, extract_symbols_from_file, get_extractor, language_for_extension, ) # --------------------------------------------------------------------------- # language_for_extension # --------------------------------------------------------------------------- class TestLanguageForExtension: def test_python_extensions(self): assert language_for_extension("py") == "py" assert language_for_extension(".py") == "py" assert language_for_extension(".PY") == "py" # case-insensitive def test_typescript_javascript(self): assert language_for_extension(".ts") == "ts" assert language_for_extension(".tsx") == "ts" assert language_for_extension(".js") == "js" assert language_for_extension(".jsx") == "js" assert language_for_extension(".mjs") == "js" assert language_for_extension(".cjs") == "js" def test_go_rust_java(self): assert language_for_extension(".go") == "go" assert language_for_extension(".rs") == "rs" assert language_for_extension(".java") == "java" def test_unsupported_returns_empty(self): assert language_for_extension(".md") == "" assert language_for_extension(".txt") == "" assert language_for_extension("") == "" assert language_for_extension(".unknown") == "" # --------------------------------------------------------------------------- # AstSymbolExtractor — Python # --------------------------------------------------------------------------- class TestAstSymbolExtractor: extractor = AstSymbolExtractor() def test_unsupported_language_returns_empty(self): assert self.extractor.extract_symbols("function foo() {}", "ts") == [] def test_syntax_error_returns_empty(self): # Never raises — callers rely on this for fallback routing. result = self.extractor.extract_symbols("def broken(:\n pass", "py") assert result == [] def test_top_level_function(self): content = "def my_func():\n return 42\n" spans = self.extractor.extract_symbols(content, "py") assert len(spans) == 1 span = spans[0] assert span.name == "my_func" assert span.kind == "function" assert span.start_line == 1 assert span.end_line == 2 def test_async_function(self): content = "async def fetch():\n return 1\n" spans = self.extractor.extract_symbols(content, "py") assert len(spans) == 1 assert spans[0].name == "fetch" assert spans[0].kind == "function" def test_top_level_class(self): content = textwrap.dedent(''' class MyClass: """docstring""" def method_a(self): return 1 async def method_b(self): return 2 ''').strip() spans = self.extractor.extract_symbols(content, "py") names = [s.name for s in spans] assert "MyClass" in names assert "method_a" in names assert "method_b" in names cls = next(s for s in spans if s.name == "MyClass") assert cls.kind == "class" assert cls.start_line == 1 # Class body extends through the last method's end_lineno. assert cls.end_line >= 7 def test_methods_classified_as_methods(self): content = textwrap.dedent(''' class Foo: def bar(self): pass def top_level(): pass ''').strip() spans = self.extractor.extract_symbols(content, "py") by_name = {s.name: s for s in spans} assert by_name["bar"].kind == "method" assert by_name["top_level"].kind == "function" def test_decorated_function(self): content = textwrap.dedent(''' @staticmethod def helper(): return "hi" ''').strip() spans = self.extractor.extract_symbols(content, "py") # Note: extractor uses node.lineno (def line) — decorators above are # excluded by design (matches user-visible symbol start at `def`). assert any(s.name == "helper" for s in spans) span = next(s for s in spans if s.name == "helper") assert span.start_line == 2 # the `def` line def test_nested_function(self): content = textwrap.dedent(''' def outer(): def inner(): return 1 return inner() ''').strip() spans = self.extractor.extract_symbols(content, "py") names = {s.name for s in spans} assert "outer" in names assert "inner" in names def test_empty_file(self): assert self.extractor.extract_symbols("", "py") == [] def test_no_symbols_in_docstring_only_file(self): content = '"""just a docstring"""\n' assert self.extractor.extract_symbols(content, "py") == [] # --------------------------------------------------------------------------- # RegexSymbolExtractor — TS/JS/Go/Rust/Java # --------------------------------------------------------------------------- class TestRegexSymbolExtractor: extractor = RegexSymbolExtractor() def test_unsupported_language_returns_empty(self): assert self.extractor.extract_symbols("def foo(): pass", "py") == [] assert self.extractor.extract_symbols("function foo() {}", "rb") == [] def test_typescript_function_declaration(self): content = textwrap.dedent(''' export function renderComponent(props: Props): JSX.Element { return
; } ''').strip() spans = self.extractor.extract_symbols(content, "ts") assert any(s.name == "renderComponent" and s.kind == "function" for s in spans) def test_typescript_async_function(self): content = "async function fetchData() {\n return await fetch();\n}\n" spans = self.extractor.extract_symbols(content, "ts") assert any(s.name == "fetchData" for s in spans) def test_typescript_arrow_function_const(self): content = "const handleClick = (e: Event) => {\n console.log(e);\n};\n" spans = self.extractor.extract_symbols(content, "ts") assert any(s.name == "handleClick" for s in spans) def test_typescript_class(self): content = textwrap.dedent(''' export abstract class BaseService { abstract run(): void; } ''').strip() spans = self.extractor.extract_symbols(content, "ts") assert any(s.name == "BaseService" and s.kind == "class" for s in spans) def test_javascript_function(self): content = "function foo() {\n return 1;\n}\n" spans = self.extractor.extract_symbols(content, "js") assert any(s.name == "foo" for s in spans) def test_javascript_arrow_const(self): content = "const bar = () => 42;\n" spans = self.extractor.extract_symbols(content, "js") assert any(s.name == "bar" for s in spans) def test_go_function(self): content = textwrap.dedent(''' package main func HandleRequest(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) } func (s *Server) Start() { // method } ''').strip() spans = self.extractor.extract_symbols(content, "go") names = {s.name for s in spans} assert "HandleRequest" in names assert "Start" in names # method receiver pattern def test_go_struct(self): content = "type Server struct {\n Addr string\n}\n" spans = self.extractor.extract_symbols(content, "go") assert any(s.name == "Server" and s.kind == "struct" for s in spans) def test_rust_function(self): content = textwrap.dedent(''' pub fn process(input: &str) -> Result { Ok(input.len()) } async fn fetch() -> Bytes { unimplemented!() } ''').strip() spans = self.extractor.extract_symbols(content, "rs") names = {s.name for s in spans} assert "process" in names assert "fetch" in names def test_rust_struct(self): content = "pub struct Config {\n pub path: String,\n}\n" spans = self.extractor.extract_symbols(content, "rs") assert any(s.name == "Config" and s.kind == "struct" for s in spans) def test_rust_impl(self): content = "impl Config {\n pub fn new() -> Self { Self { path: String::new() } }\n}\n" spans = self.extractor.extract_symbols(content, "rs") assert any(s.name == "Config" and s.kind == "impl" for s in spans) def test_java_class(self): content = textwrap.dedent(''' package com.example; public class UserService { public User findById(long id) { return null; } } ''').strip() spans = self.extractor.extract_symbols(content, "java") assert any(s.name == "UserService" and s.kind == "class" for s in spans) def test_java_method(self): content = "public User findById(long id) {\n return null;\n}\n" spans = self.extractor.extract_symbols(content, "java") assert any(s.name == "findById" and s.kind == "function" for s in spans) def test_end_line_extends_to_next_symbol(self): # First symbol's end_line is the line before the second symbol starts. content = textwrap.dedent(''' function first() { return 1; } function second() { return 2; } ''').strip() spans = self.extractor.extract_symbols(content, "js") spans.sort(key=lambda s: s.start_line) first = spans[0] second = spans[1] assert first.name == "first" assert second.name == "second" assert first.end_line == second.start_line - 1 def test_last_symbol_end_line_is_eof(self): content = "function only() {\n return 1;\n}\n" spans = self.extractor.extract_symbols(content, "js") assert len(spans) == 1 assert spans[0].end_line == len(content.splitlines()) # --------------------------------------------------------------------------- # get_extractor + integration # --------------------------------------------------------------------------- class TestGetExtractor: def test_python_returns_ast_extractor(self): ext = get_extractor("py") assert ext is not None assert isinstance(ext, AstSymbolExtractor) def test_typescript_returns_regex_extractor(self): ext = get_extractor("ts") assert ext is not None assert isinstance(ext, RegexSymbolExtractor) def test_unsupported_returns_none(self): assert get_extractor("md") is None assert get_extractor("") is None assert get_extractor("unknown") is None class TestExtractSymbolsFromFile: def test_python_file(self, tmp_path): path = tmp_path / "module.py" path.write_text("def hello():\n return 'world'\n", encoding="utf-8") spans, lang = extract_symbols_from_file(path) assert lang == "py" assert any(s.name == "hello" for s in spans) def test_unsupported_extension(self, tmp_path): path = tmp_path / "notes.md" path.write_text("# Hello\n", encoding="utf-8") spans, lang = extract_symbols_from_file(path) assert lang == "" assert spans == [] def test_missing_file_returns_empty(self, tmp_path): path = tmp_path / "nonexistent.py" spans, lang = extract_symbols_from_file(path) # lang is detected from extension even if read fails. assert lang == "py" assert spans == [] # --------------------------------------------------------------------------- # SymbolSpan dataclass # --------------------------------------------------------------------------- class TestSymbolSpan: def test_frozen_dataclass(self): span = SymbolSpan(name="foo", kind="function", start_line=1, end_line=3) assert span.name == "foo" with pytest.raises(Exception): span.name = "bar" # type: ignore[misc] — frozen def test_equality(self): a = SymbolSpan("foo", "function", 1, 3) b = SymbolSpan("foo", "function", 1, 3) assert a == b assert hash(a) == hash(b)