360 lines
13 KiB
Python
360 lines
13 KiB
Python
"""Unit tests for SymbolExtractor — AstSymbolExtractor + RegexSymbolExtractor.
|
|
|
|
Covers R22 (file reading supports symbol/function granularity) and KTD1
|
|
(Python ast + language-aware regex, no tree-sitter dependency).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import textwrap
|
|
|
|
import pytest
|
|
|
|
from agentkit.tools.symbol_extractor import (
|
|
AstSymbolExtractor,
|
|
RegexSymbolExtractor,
|
|
SymbolSpan,
|
|
extract_symbols_from_file,
|
|
get_extractor,
|
|
language_for_extension,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# language_for_extension
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestLanguageForExtension:
    """Checks for ``language_for_extension`` over known and unknown suffixes."""

    def test_python_extensions(self):
        """Bare, dotted and upper-case Python suffixes all map to ``"py"``."""
        # ".PY" is included to pin case-insensitive matching.
        for suffix in ("py", ".py", ".PY"):
            assert language_for_extension(suffix) == "py"

    def test_typescript_javascript(self):
        """TypeScript and JavaScript suffix variants map to ``"ts"`` / ``"js"``."""
        expected_by_suffix = {
            ".ts": "ts",
            ".tsx": "ts",
            ".js": "js",
            ".jsx": "js",
            ".mjs": "js",
            ".cjs": "js",
        }
        for suffix, expected in expected_by_suffix.items():
            assert language_for_extension(suffix) == expected

    def test_go_rust_java(self):
        """Go, Rust and Java suffixes map to their own language keys."""
        expected_by_suffix = {".go": "go", ".rs": "rs", ".java": "java"}
        for suffix, expected in expected_by_suffix.items():
            assert language_for_extension(suffix) == expected

    def test_unsupported_returns_empty(self):
        """Unrecognised suffixes, including the empty string, map to ``""``."""
        for suffix in (".md", ".txt", "", ".unknown"):
            assert language_for_extension(suffix) == ""
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# AstSymbolExtractor — Python
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAstSymbolExtractor:
    """Tests for ``AstSymbolExtractor.extract_symbols`` on Python source."""

    # One extractor instance is shared by every test in this class.
    extractor = AstSymbolExtractor()

    def test_unsupported_language_returns_empty(self):
        """A non-Python language key yields no spans."""
        found = self.extractor.extract_symbols("function foo() {}", "ts")
        assert found == []

    def test_syntax_error_returns_empty(self):
        """Source that does not parse yields an empty list."""
        # Never raises — callers rely on this for fallback routing.
        assert self.extractor.extract_symbols("def broken(:\n pass", "py") == []

    def test_top_level_function(self):
        """A single module-level ``def`` gives one ``function`` span."""
        source = "def my_func():\n return 42\n"
        found = self.extractor.extract_symbols(source, "py")
        assert len(found) == 1
        only = found[0]
        assert (only.name, only.kind) == ("my_func", "function")
        assert (only.start_line, only.end_line) == (1, 2)

    def test_async_function(self):
        """An ``async def`` is reported with kind ``function``."""
        source = "async def fetch():\n return 1\n"
        found = self.extractor.extract_symbols(source, "py")
        assert len(found) == 1
        only = found[0]
        assert only.name == "fetch"
        assert only.kind == "function"

    def test_top_level_class(self):
        """A class and both of its methods are reported; the class starts at line 1."""
        source = textwrap.dedent('''
            class MyClass:
                """docstring"""

                def method_a(self):
                    return 1

                async def method_b(self):
                    return 2
        ''').strip()
        found = self.extractor.extract_symbols(source, "py")
        seen = [span.name for span in found]
        for wanted in ("MyClass", "method_a", "method_b"):
            assert wanted in seen

        class_span = next(span for span in found if span.name == "MyClass")
        assert class_span.kind == "class"
        assert class_span.start_line == 1
        # Class body extends through the last method's end_lineno.
        assert class_span.end_line >= 7

    def test_methods_classified_as_methods(self):
        """A ``def`` inside a class is a ``method``; a module-level one is a ``function``."""
        source = textwrap.dedent('''
            class Foo:
                def bar(self):
                    pass

            def top_level():
                pass
        ''').strip()
        found = self.extractor.extract_symbols(source, "py")
        kind_of = {span.name: span.kind for span in found}
        assert kind_of["bar"] == "method"
        assert kind_of["top_level"] == "function"

    def test_decorated_function(self):
        """A decorated function's span starts on its ``def`` line."""
        source = textwrap.dedent('''
            @staticmethod
            def helper():
                return "hi"
        ''').strip()
        found = self.extractor.extract_symbols(source, "py")
        # Note: extractor uses node.lineno (def line) — decorators above are
        # excluded by design (matches user-visible symbol start at `def`).
        helpers = [span for span in found if span.name == "helper"]
        assert helpers
        assert helpers[0].start_line == 2  # the `def` line

    def test_nested_function(self):
        """Both the outer function and the one nested inside it are reported."""
        source = textwrap.dedent('''
            def outer():
                def inner():
                    return 1
                return inner()
        ''').strip()
        found = self.extractor.extract_symbols(source, "py")
        seen = {span.name for span in found}
        assert "outer" in seen
        assert "inner" in seen

    def test_empty_file(self):
        """Empty source yields no spans."""
        found = self.extractor.extract_symbols("", "py")
        assert found == []

    def test_no_symbols_in_docstring_only_file(self):
        """A module consisting only of a docstring yields no spans."""
        source = '"""just a docstring"""\n'
        found = self.extractor.extract_symbols(source, "py")
        assert found == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# RegexSymbolExtractor — TS/JS/Go/Rust/Java
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRegexSymbolExtractor:
    """Tests for ``RegexSymbolExtractor.extract_symbols`` on TS/JS/Go/Rust/Java."""

    # One extractor instance is shared by every test in this class.
    extractor = RegexSymbolExtractor()

    def test_unsupported_language_returns_empty(self):
        """Language keys this extractor is not asked to handle yield no spans."""
        rejected = (
            ("def foo(): pass", "py"),
            ("function foo() {}", "rb"),
        )
        for snippet, lang in rejected:
            assert self.extractor.extract_symbols(snippet, lang) == []

    def test_typescript_function_declaration(self):
        """An exported TS function declaration is reported as a ``function``."""
        source = textwrap.dedent('''
            export function renderComponent(props: Props): JSX.Element {
                return <div/>;
            }
        ''').strip()
        found = self.extractor.extract_symbols(source, "ts")
        pairs = {(span.name, span.kind) for span in found}
        assert ("renderComponent", "function") in pairs

    def test_typescript_async_function(self):
        """An ``async function`` declaration is reported by name."""
        source = "async function fetchData() {\n return await fetch();\n}\n"
        found = self.extractor.extract_symbols(source, "ts")
        assert "fetchData" in {span.name for span in found}

    def test_typescript_arrow_function_const(self):
        """A ``const`` bound to an arrow function is reported by name."""
        source = "const handleClick = (e: Event) => {\n console.log(e);\n};\n"
        found = self.extractor.extract_symbols(source, "ts")
        assert "handleClick" in {span.name for span in found}

    def test_typescript_class(self):
        """An ``export abstract class`` is reported as a ``class``."""
        source = textwrap.dedent('''
            export abstract class BaseService {
                abstract run(): void;
            }
        ''').strip()
        found = self.extractor.extract_symbols(source, "ts")
        pairs = {(span.name, span.kind) for span in found}
        assert ("BaseService", "class") in pairs

    def test_javascript_function(self):
        """A plain JS function declaration is reported by name."""
        source = "function foo() {\n return 1;\n}\n"
        found = self.extractor.extract_symbols(source, "js")
        assert "foo" in {span.name for span in found}

    def test_javascript_arrow_const(self):
        """A one-line JS arrow function bound to ``const`` is reported by name."""
        source = "const bar = () => 42;\n"
        found = self.extractor.extract_symbols(source, "js")
        assert "bar" in {span.name for span in found}

    def test_go_function(self):
        """Go free functions and receiver methods are both reported by name."""
        source = textwrap.dedent('''
            package main

            func HandleRequest(w http.ResponseWriter, r *http.Request) {
                w.WriteHeader(200)
            }

            func (s *Server) Start() {
                // method
            }
        ''').strip()
        found = self.extractor.extract_symbols(source, "go")
        seen = {span.name for span in found}
        assert "HandleRequest" in seen
        assert "Start" in seen  # method receiver pattern

    def test_go_struct(self):
        """A Go ``type ... struct`` is reported with kind ``struct``."""
        source = "type Server struct {\n Addr string\n}\n"
        found = self.extractor.extract_symbols(source, "go")
        pairs = {(span.name, span.kind) for span in found}
        assert ("Server", "struct") in pairs

    def test_rust_function(self):
        """Rust ``pub fn`` and ``async fn`` items are both reported by name."""
        source = textwrap.dedent('''
            pub fn process(input: &str) -> Result<usize, Error> {
                Ok(input.len())
            }

            async fn fetch() -> Bytes {
                unimplemented!()
            }
        ''').strip()
        found = self.extractor.extract_symbols(source, "rs")
        seen = {span.name for span in found}
        assert "process" in seen
        assert "fetch" in seen

    def test_rust_struct(self):
        """A Rust ``pub struct`` is reported with kind ``struct``."""
        source = "pub struct Config {\n pub path: String,\n}\n"
        found = self.extractor.extract_symbols(source, "rs")
        pairs = {(span.name, span.kind) for span in found}
        assert ("Config", "struct") in pairs

    def test_rust_impl(self):
        """A Rust ``impl`` block is reported under the type name with kind ``impl``."""
        source = "impl Config {\n pub fn new() -> Self { Self { path: String::new() } }\n}\n"
        found = self.extractor.extract_symbols(source, "rs")
        pairs = {(span.name, span.kind) for span in found}
        assert ("Config", "impl") in pairs

    def test_java_class(self):
        """A ``public class`` declaration is reported with kind ``class``."""
        source = textwrap.dedent('''
            package com.example;

            public class UserService {
                public User findById(long id) {
                    return null;
                }
            }
        ''').strip()
        found = self.extractor.extract_symbols(source, "java")
        pairs = {(span.name, span.kind) for span in found}
        assert ("UserService", "class") in pairs

    def test_java_method(self):
        """A Java method signature is reported with kind ``function``."""
        source = "public User findById(long id) {\n return null;\n}\n"
        found = self.extractor.extract_symbols(source, "java")
        pairs = {(span.name, span.kind) for span in found}
        assert ("findById", "function") in pairs

    def test_end_line_extends_to_next_symbol(self):
        """A span ends on the line just before the following symbol begins."""
        # First symbol's end_line is the line before the second symbol starts.
        source = textwrap.dedent('''
            function first() {
                return 1;
            }

            function second() {
                return 2;
            }
        ''').strip()
        found = self.extractor.extract_symbols(source, "js")
        ordered = sorted(found, key=lambda span: span.start_line)
        leading = ordered[0]
        trailing = ordered[1]
        assert leading.name == "first"
        assert trailing.name == "second"
        assert leading.end_line == trailing.start_line - 1

    def test_last_symbol_end_line_is_eof(self):
        """The only (hence last) symbol's span runs to the final line of the text."""
        source = "function only() {\n return 1;\n}\n"
        found = self.extractor.extract_symbols(source, "js")
        assert len(found) == 1
        total_lines = len(source.splitlines())
        assert found[0].end_line == total_lines
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# get_extractor + integration
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGetExtractor:
    """Tests for ``get_extractor``: which extractor each language key selects."""

    def test_python_returns_ast_extractor(self):
        """``"py"`` selects an ``AstSymbolExtractor`` instance."""
        chosen = get_extractor("py")
        assert chosen is not None
        assert isinstance(chosen, AstSymbolExtractor)

    def test_typescript_returns_regex_extractor(self):
        """``"ts"`` selects a ``RegexSymbolExtractor`` instance."""
        chosen = get_extractor("ts")
        assert chosen is not None
        assert isinstance(chosen, RegexSymbolExtractor)

    def test_unsupported_returns_none(self):
        """Unknown or empty language keys select nothing (``None``)."""
        for lang in ("md", "", "unknown"):
            assert get_extractor(lang) is None
|
|
|
|
|
|
class TestExtractSymbolsFromFile:
    """Tests for ``extract_symbols_from_file`` against real files on disk."""

    def test_python_file(self, tmp_path):
        """A ``.py`` file is read, detected as ``"py"`` and its function reported."""
        target = tmp_path / "module.py"
        target.write_text("def hello():\n return 'world'\n", encoding="utf-8")
        found, detected = extract_symbols_from_file(target)
        assert detected == "py"
        assert "hello" in {span.name for span in found}

    def test_unsupported_extension(self, tmp_path):
        """A ``.md`` file gives an empty language key and no spans."""
        target = tmp_path / "notes.md"
        target.write_text("# Hello\n", encoding="utf-8")
        found, detected = extract_symbols_from_file(target)
        assert detected == ""
        assert found == []

    def test_missing_file_returns_empty(self, tmp_path):
        """A path that does not exist gives no spans but still reports ``"py"``."""
        target = tmp_path / "nonexistent.py"
        found, detected = extract_symbols_from_file(target)
        # lang is detected from extension even if read fails.
        assert detected == "py"
        assert found == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# SymbolSpan dataclass
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSymbolSpan:
    """Tests for the ``SymbolSpan`` value type: immutability, equality, hashing."""

    def test_frozen_dataclass(self):
        """Assigning to a field of a constructed span is rejected."""
        span = SymbolSpan(name="foo", kind="function", start_line=1, end_line=3)
        assert span.name == "foo"
        # Expect AttributeError rather than a blanket Exception, so an unrelated
        # failure cannot make this test pass. dataclasses.FrozenInstanceError
        # subclasses AttributeError.
        # NOTE(review): assumes SymbolSpan is a frozen dataclass, as this test's
        # name says — confirm against the definition.
        with pytest.raises(AttributeError):
            span.name = "bar"  # type: ignore[misc]  # frozen
        # The rejected write must leave the original value in place.
        assert span.name == "foo"

    def test_equality(self):
        """Spans built from identical field values compare equal and hash alike."""
        a = SymbolSpan("foo", "function", 1, 3)
        b = SymbolSpan("foo", "function", 1, 3)
        assert a == b
        assert hash(a) == hash(b)
|