diff --git a/lizard_languages/golike.py b/lizard_languages/golike.py index a1314607..a78c0692 100644 --- a/lizard_languages/golike.py +++ b/lizard_languages/golike.py @@ -57,6 +57,11 @@ def _expect_function_dec(self, token): self.next(self._function_dec, token) elif token == "<": self.next(self._generalize, token) + elif token == "[": + # Square-bracket type parameters, e.g. Scala `def f[T](x: T)` or + # Go `func F[T any](x T)`. Skip them like the `<>` generics above + # so the following `(` params still register the function. + self.next(self._generalize_type_params, token) else: self._state = self._state_global @@ -64,6 +69,10 @@ def _expect_function_dec(self, token): def _generalize(self, tokens): pass + @CodeStateMachine.read_inside_brackets_then("[]", "_expect_function_dec") + def _generalize_type_params(self, tokens): + pass + @CodeStateMachine.read_inside_brackets_then("()", '_function_name') def _member_function(self, tokens): self.context.add_to_long_function_name(tokens) diff --git a/test/test_languages/testGo.py b/test/test_languages/testGo.py index 2c268ad6..7c4ec533 100644 --- a/test/test_languages/testGo.py +++ b/test/test_languages/testGo.py @@ -166,3 +166,17 @@ def test_sql_query_with_question_marks(self): self.assertEqual(1, len(result)) self.assertEqual("getQuery", result[0].name) self.assertEqual(1, result[0].cyclomatic_complexity) + + def test_generic_function_with_type_param(self): + result = get_go_function_list(''' + func Map[T any](x T) T { return x } + ''') + self.assertEqual(1, len(result)) + self.assertEqual("Map", result[0].name) + + def test_generic_function_with_multiple_type_params(self): + result = get_go_function_list(''' + func Reduce[T any, U any](xs []T, acc U) U { return acc } + ''') + self.assertEqual(1, len(result)) + self.assertEqual("Reduce", result[0].name) diff --git a/test/test_languages/testScala.py b/test/test_languages/testScala.py index e2aa41a6..f325e481 100644 --- a/test/test_languages/testScala.py +++ b/test/test_languages/testScala.py @@ -128,3 +128,30 @@ def list2(): Future[Seq[Person]] = a + b ''') self.assertEqual(2, len(result)) self.assertEqual('list2', result[0].name) + + def test_generic_function_with_type_param(self): + result = get_scala_function_list(''' + def identity[T](x: T): T = { x } + ''') + self.assertEqual(1, len(result)) + self.assertEqual("identity", result[0].name) + self.assertEqual(1, result[0].parameter_count) + + def test_generic_function_with_multiple_type_params(self): + result = get_scala_function_list(''' + def merge[A, B](a: A, b: B): String = { a.toString + b.toString } + ''') + self.assertEqual(1, len(result)) + self.assertEqual("merge", result[0].name) + self.assertEqual(2, result[0].parameter_count) + + def test_generic_method_in_class_not_dropped(self): + result = get_scala_function_list(''' + class Box { + def plain(x: Int): Int = { x + 1 } + def generic[T](x: T): T = { x } + def after(y: Int): Int = { y * 2 } + } + ''') + names = sorted(f.name for f in result) + self.assertEqual(["after", "generic", "plain"], names)