# fischer-agentkit/tests/unit/test_symbol_extractor.py (360 lines, 13 KiB, Python)

"""Unit tests for SymbolExtractor — AstSymbolExtractor + RegexSymbolExtractor.
Covers R22 (file reading supports symbol/function granularity) and KTD1
(Python ast + language-aware regex, no tree-sitter dependency).
"""
from __future__ import annotations
import textwrap
import pytest
from agentkit.tools.symbol_extractor import (
AstSymbolExtractor,
RegexSymbolExtractor,
SymbolSpan,
extract_symbols_from_file,
get_extractor,
language_for_extension,
)
# ---------------------------------------------------------------------------
# language_for_extension
# ---------------------------------------------------------------------------
class TestLanguageForExtension:
    """Checks for the extension-to-language-key mapping."""

    def test_python_extensions(self):
        """"py" is recognised with or without a leading dot, in any case."""
        # ".PY" exercises the case-insensitive path.
        for raw in ("py", ".py", ".PY"):
            assert language_for_extension(raw) == "py"

    def test_typescript_javascript(self):
        """TypeScript and JavaScript variants collapse to "ts" / "js"."""
        expected = {
            ".ts": "ts",
            ".tsx": "ts",
            ".js": "js",
            ".jsx": "js",
            ".mjs": "js",
            ".cjs": "js",
        }
        for extension, language in expected.items():
            assert language_for_extension(extension) == language

    def test_go_rust_java(self):
        """Go, Rust and Java extensions map to their own keys."""
        for extension, language in ((".go", "go"), (".rs", "rs"), (".java", "java")):
            assert language_for_extension(extension) == language

    def test_unsupported_returns_empty(self):
        """Unknown or empty extensions yield the empty string."""
        for extension in (".md", ".txt", "", ".unknown"):
            assert language_for_extension(extension) == ""
# ---------------------------------------------------------------------------
# AstSymbolExtractor — Python
# ---------------------------------------------------------------------------
class TestAstSymbolExtractor:
    """Tests for AstSymbolExtractor against Python source snippets.

    Multi-line fixtures are written with their natural nesting inside
    ``textwrap.dedent`` so that the snippet handed to the extractor is valid
    Python; a flattened fixture would be a syntax error.
    """

    # One shared instance, reused by every test in this class.
    extractor = AstSymbolExtractor()

    def test_unsupported_language_returns_empty(self):
        """A non-Python language key produces no spans."""
        assert self.extractor.extract_symbols("function foo() {}", "ts") == []

    def test_syntax_error_returns_empty(self):
        """Unparseable Python yields an empty list instead of raising."""
        # Never raises — callers rely on this for fallback routing.
        result = self.extractor.extract_symbols("def broken(:\n pass", "py")
        assert result == []

    def test_top_level_function(self):
        """A single module-level ``def`` is reported with its line range."""
        content = "def my_func():\n return 42\n"
        spans = self.extractor.extract_symbols(content, "py")
        assert len(spans) == 1
        span = spans[0]
        assert span.name == "my_func"
        assert span.kind == "function"
        assert span.start_line == 1
        assert span.end_line == 2

    def test_async_function(self):
        """``async def`` is reported with kind "function"."""
        content = "async def fetch():\n return 1\n"
        spans = self.extractor.extract_symbols(content, "py")
        assert len(spans) == 1
        assert spans[0].name == "fetch"
        assert spans[0].kind == "function"

    def test_top_level_class(self):
        """A class and both of its methods are reported."""
        # The blank lines are part of the fixture: they put `method_b` on
        # line 7, which the end_line assertion below depends on.
        content = textwrap.dedent('''
            class MyClass:
                """docstring"""

                def method_a(self):
                    return 1

                async def method_b(self):
                    return 2
        ''').strip()
        spans = self.extractor.extract_symbols(content, "py")
        names = [s.name for s in spans]
        assert "MyClass" in names
        assert "method_a" in names
        assert "method_b" in names
        cls = next(s for s in spans if s.name == "MyClass")
        assert cls.kind == "class"
        assert cls.start_line == 1
        # Class body extends through the last method's end_lineno.
        assert cls.end_line >= 7

    def test_methods_classified_as_methods(self):
        """A def inside a class is a "method"; a module-level def is a "function"."""
        content = textwrap.dedent('''
            class Foo:
                def bar(self):
                    pass
            def top_level():
                pass
        ''').strip()
        spans = self.extractor.extract_symbols(content, "py")
        by_name = {s.name: s for s in spans}
        assert by_name["bar"].kind == "method"
        assert by_name["top_level"].kind == "function"

    def test_decorated_function(self):
        """A decorated function starts at its ``def`` line, not the decorator."""
        content = textwrap.dedent('''
            @staticmethod
            def helper():
                return "hi"
        ''').strip()
        spans = self.extractor.extract_symbols(content, "py")
        # Note: extractor uses node.lineno (def line) — decorators above are
        # excluded by design (matches user-visible symbol start at `def`).
        assert any(s.name == "helper" for s in spans)
        span = next(s for s in spans if s.name == "helper")
        assert span.start_line == 2  # the `def` line

    def test_nested_function(self):
        """Both the outer function and the function nested in it are reported."""
        content = textwrap.dedent('''
            def outer():
                def inner():
                    return 1
                return inner()
        ''').strip()
        spans = self.extractor.extract_symbols(content, "py")
        names = {s.name for s in spans}
        assert "outer" in names
        assert "inner" in names

    def test_empty_file(self):
        """Empty content yields no spans."""
        assert self.extractor.extract_symbols("", "py") == []

    def test_no_symbols_in_docstring_only_file(self):
        """A module containing only a docstring yields no spans."""
        content = '"""just a docstring"""\n'
        assert self.extractor.extract_symbols(content, "py") == []
# ---------------------------------------------------------------------------
# RegexSymbolExtractor — TS/JS/Go/Rust/Java
# ---------------------------------------------------------------------------
class TestRegexSymbolExtractor:
    """Tests for RegexSymbolExtractor on TS/JS/Go/Rust/Java snippets.

    Multi-line fixtures keep their natural nesting inside ``textwrap.dedent``
    so that body lines are indented relative to the declarations, as they
    would be in real source files.
    """

    # One shared instance, reused by every test in this class.
    extractor = RegexSymbolExtractor()

    def test_unsupported_language_returns_empty(self):
        """Language keys this extractor does not handle produce no spans."""
        assert self.extractor.extract_symbols("def foo(): pass", "py") == []
        assert self.extractor.extract_symbols("function foo() {}", "rb") == []

    def test_typescript_function_declaration(self):
        """An exported TS function declaration is reported as a function."""
        content = textwrap.dedent('''
            export function renderComponent(props: Props): JSX.Element {
                return <div/>;
            }
        ''').strip()
        spans = self.extractor.extract_symbols(content, "ts")
        assert any(s.name == "renderComponent" and s.kind == "function" for s in spans)

    def test_typescript_async_function(self):
        """An ``async function`` declaration is found by name."""
        content = "async function fetchData() {\n return await fetch();\n}\n"
        spans = self.extractor.extract_symbols(content, "ts")
        assert any(s.name == "fetchData" for s in spans)

    def test_typescript_arrow_function_const(self):
        """A ``const`` bound to an arrow function is found by name."""
        content = "const handleClick = (e: Event) => {\n console.log(e);\n};\n"
        spans = self.extractor.extract_symbols(content, "ts")
        assert any(s.name == "handleClick" for s in spans)

    def test_typescript_class(self):
        """An exported abstract class is reported with kind "class"."""
        content = textwrap.dedent('''
            export abstract class BaseService {
                abstract run(): void;
            }
        ''').strip()
        spans = self.extractor.extract_symbols(content, "ts")
        assert any(s.name == "BaseService" and s.kind == "class" for s in spans)

    def test_javascript_function(self):
        """A plain JS function declaration is found by name."""
        content = "function foo() {\n return 1;\n}\n"
        spans = self.extractor.extract_symbols(content, "js")
        assert any(s.name == "foo" for s in spans)

    def test_javascript_arrow_const(self):
        """A single-line JS arrow function bound to ``const`` is found by name."""
        content = "const bar = () => 42;\n"
        spans = self.extractor.extract_symbols(content, "js")
        assert any(s.name == "bar" for s in spans)

    def test_go_function(self):
        """Go functions are found both with and without a receiver."""
        content = textwrap.dedent('''
            package main
            func HandleRequest(w http.ResponseWriter, r *http.Request) {
                w.WriteHeader(200)
            }
            func (s *Server) Start() {
                // method
            }
        ''').strip()
        spans = self.extractor.extract_symbols(content, "go")
        names = {s.name for s in spans}
        assert "HandleRequest" in names
        assert "Start" in names  # method receiver pattern

    def test_go_struct(self):
        """A Go ``type ... struct`` is reported with kind "struct"."""
        content = "type Server struct {\n Addr string\n}\n"
        spans = self.extractor.extract_symbols(content, "go")
        assert any(s.name == "Server" and s.kind == "struct" for s in spans)

    def test_rust_function(self):
        """Rust ``pub fn`` and ``async fn`` items are both found by name."""
        content = textwrap.dedent('''
            pub fn process(input: &str) -> Result<usize, Error> {
                Ok(input.len())
            }
            async fn fetch() -> Bytes {
                unimplemented!()
            }
        ''').strip()
        spans = self.extractor.extract_symbols(content, "rs")
        names = {s.name for s in spans}
        assert "process" in names
        assert "fetch" in names

    def test_rust_struct(self):
        """A Rust ``pub struct`` is reported with kind "struct"."""
        content = "pub struct Config {\n pub path: String,\n}\n"
        spans = self.extractor.extract_symbols(content, "rs")
        assert any(s.name == "Config" and s.kind == "struct" for s in spans)

    def test_rust_impl(self):
        """A Rust ``impl`` block is reported under the type name with kind "impl"."""
        content = "impl Config {\n pub fn new() -> Self { Self { path: String::new() } }\n}\n"
        spans = self.extractor.extract_symbols(content, "rs")
        assert any(s.name == "Config" and s.kind == "impl" for s in spans)

    def test_java_class(self):
        """A public Java class is reported with kind "class"."""
        content = textwrap.dedent('''
            package com.example;
            public class UserService {
                public User findById(long id) {
                    return null;
                }
            }
        ''').strip()
        spans = self.extractor.extract_symbols(content, "java")
        assert any(s.name == "UserService" and s.kind == "class" for s in spans)

    def test_java_method(self):
        """A Java method declaration is reported with kind "function"."""
        content = "public User findById(long id) {\n return null;\n}\n"
        spans = self.extractor.extract_symbols(content, "java")
        assert any(s.name == "findById" and s.kind == "function" for s in spans)

    def test_end_line_extends_to_next_symbol(self):
        """A symbol's span ends on the line just before the next symbol starts."""
        # First symbol's end_line is the line before the second symbol starts.
        content = textwrap.dedent('''
            function first() {
                return 1;
            }
            function second() {
                return 2;
            }
        ''').strip()
        spans = self.extractor.extract_symbols(content, "js")
        # Order by position without mutating the list the extractor returned.
        ordered = sorted(spans, key=lambda s: s.start_line)
        first = ordered[0]
        second = ordered[1]
        assert first.name == "first"
        assert second.name == "second"
        assert first.end_line == second.start_line - 1

    def test_last_symbol_end_line_is_eof(self):
        """The only (hence last) symbol's span runs to the final line of the file."""
        content = "function only() {\n return 1;\n}\n"
        spans = self.extractor.extract_symbols(content, "js")
        assert len(spans) == 1
        assert spans[0].end_line == len(content.splitlines())
# ---------------------------------------------------------------------------
# get_extractor + integration
# ---------------------------------------------------------------------------
class TestGetExtractor:
    """Checks which extractor type get_extractor hands back per language key."""

    def test_python_returns_ast_extractor(self):
        """The "py" key resolves to an AstSymbolExtractor."""
        chosen = get_extractor("py")
        assert chosen is not None
        assert isinstance(chosen, AstSymbolExtractor)

    def test_typescript_returns_regex_extractor(self):
        """The "ts" key resolves to a RegexSymbolExtractor."""
        chosen = get_extractor("ts")
        assert chosen is not None
        assert isinstance(chosen, RegexSymbolExtractor)

    def test_unsupported_returns_none(self):
        """Unknown or empty language keys resolve to None."""
        for language in ("md", "", "unknown"):
            assert get_extractor(language) is None
class TestExtractSymbolsFromFile:
    """File-based checks for extract_symbols_from_file using pytest's tmp_path."""

    def test_python_file(self, tmp_path):
        """A .py file on disk yields language "py" and the function it defines."""
        source_file = tmp_path / "module.py"
        source_file.write_text("def hello():\n return 'world'\n", encoding="utf-8")
        symbols, language = extract_symbols_from_file(source_file)
        assert language == "py"
        assert "hello" in [symbol.name for symbol in symbols]

    def test_unsupported_extension(self, tmp_path):
        """A .md file yields an empty language key and no spans."""
        notes_file = tmp_path / "notes.md"
        notes_file.write_text("# Hello\n", encoding="utf-8")
        symbols, language = extract_symbols_from_file(notes_file)
        assert language == ""
        assert symbols == []

    def test_missing_file_returns_empty(self, tmp_path):
        """A path that does not exist yields no spans but still reports "py"."""
        absent_file = tmp_path / "nonexistent.py"
        symbols, language = extract_symbols_from_file(absent_file)
        # lang is detected from extension even if read fails.
        assert language == "py"
        assert symbols == []
# ---------------------------------------------------------------------------
# SymbolSpan dataclass
# ---------------------------------------------------------------------------
class TestSymbolSpan:
    """Value-object checks for SymbolSpan: immutability, equality, hashing."""

    def test_frozen_dataclass(self):
        """Assigning to a field is rejected and leaves the value untouched."""
        span = SymbolSpan(name="foo", kind="function", start_line=1, end_line=3)
        assert span.name == "foo"
        # dataclasses.FrozenInstanceError subclasses AttributeError, so this
        # pins the immutability failure instead of accepting any Exception
        # (assumes SymbolSpan is a frozen dataclass, as the test name says).
        with pytest.raises(AttributeError):
            span.name = "bar"  # type: ignore[misc] — frozen
        assert span.name == "foo"

    def test_equality(self):
        """Two spans built from the same values compare equal and hash alike."""
        a = SymbolSpan("foo", "function", 1, 3)
        b = SymbolSpan("foo", "function", 1, 3)
        assert a == b
        assert hash(a) == hash(b)