From 33fba06e516d64cfd8b8d8433692b94c31ba9bf7 Mon Sep 17 00:00:00 2001 From: Fawaz Farid Date: Thu, 26 Mar 2026 08:04:25 +0300 Subject: [PATCH] Fix GetAllFunctions to support all languages GetAllFunctions previously special-cased only Ruby and Python and silently sent every other language through the JavaScript node types, causing it to miss Go methods (method_declaration) and all Rust functions (function_item). Add functionNodeTypes mapping for all supported languages (JS, TS, Python, Go, Ruby, Rust) with their respective AST node types. Refactor extraction logic to handle language-specific patterns for qualified method names (e.g., User.login for Go/Python, User#login for Ruby, User::login for Rust). Add language-specific helper functions: - getGoMethodName: Extract receiver type from method_declaration - getPythonFunctionName: Walk AST to find parent class - getRubyMethodName: Find parent class, use # separator - getRustFunctionName: Find impl block, use :: separator Add comprehensive test coverage. Rename existing test to TestGetAllFunctions_JavaScript for clarity. --- kai-core/detect/detect.go | 279 +++++++++++++++++++++++++-------- kai-core/detect/detect_test.go | 129 ++++++++++++++- 2 files changed, 341 insertions(+), 67 deletions(-) diff --git a/kai-core/detect/detect.go b/kai-core/detect/detect.go index e7c1b48b..c824bf0d 100644 --- a/kai-core/detect/detect.go +++ b/kai-core/detect/detect.go @@ -11,6 +11,37 @@ import ( "kai-core/parse" ) +// functionNodeTypes maps each language to the AST node types that represent +// function-like declarations (functions, methods, etc.) 
+var functionNodeTypes = map[string][]string{ + "js": { + "function_declaration", // function foo() {} + "method_definition", // class methods + "lexical_declaration", // const foo = () => {} + "variable_declaration", // var foo = function() {} + }, + "ts": { + "function_declaration", + "method_definition", + "lexical_declaration", + "variable_declaration", + }, + "py": { + "function_definition", // Both standalone functions and methods + }, + "go": { + "function_declaration", // func Foo() {} + "method_declaration", // func (T) Method() {} + }, + "rb": { + "method", // def foo + "singleton_method", // def self.foo + }, + "rs": { + "function_item", + }, +} + // ChangeCategory represents a type of change. type ChangeCategory string @@ -56,12 +87,12 @@ const ( DependencyUpdated ChangeCategory = "DEPENDENCY_UPDATED" // Semantic config changes - FeatureFlagChanged ChangeCategory = "FEATURE_FLAG_CHANGED" - TimeoutChanged ChangeCategory = "TIMEOUT_CHANGED" - LimitChanged ChangeCategory = "LIMIT_CHANGED" - RetryConfigChanged ChangeCategory = "RETRY_CONFIG_CHANGED" - EndpointChanged ChangeCategory = "ENDPOINT_CHANGED" - CredentialChanged ChangeCategory = "CREDENTIAL_CHANGED" + FeatureFlagChanged ChangeCategory = "FEATURE_FLAG_CHANGED" + TimeoutChanged ChangeCategory = "TIMEOUT_CHANGED" + LimitChanged ChangeCategory = "LIMIT_CHANGED" + RetryConfigChanged ChangeCategory = "RETRY_CONFIG_CHANGED" + EndpointChanged ChangeCategory = "ENDPOINT_CHANGED" + CredentialChanged ChangeCategory = "CREDENTIAL_CHANGED" // Schema/migration changes SchemaFieldAdded ChangeCategory = "SCHEMA_FIELD_ADDED" @@ -243,69 +274,57 @@ func GetAllFunctions(parsed *parse.ParsedFile, content []byte, lang ...string) m l = lang[0] } - switch l { - case "rb": - // Ruby: method and singleton_method nodes - for _, node := range parsed.FindNodesOfType("method") { - name := getFunctionName(node, content) - if name != "" { - body := getFunctionBody(node, content) - funcs[name] = &FuncInfo{Name: name, Node: node, 
Body: body} - } - } - for _, node := range parsed.FindNodesOfType("singleton_method") { - name := getFunctionName(node, content) - if name != "" { - body := getFunctionBody(node, content) - funcs["self."+name] = &FuncInfo{Name: "self." + name, Node: node, Body: body} - } - } - - case "py": - // Python: function_definition nodes - for _, node := range parsed.FindNodesOfType("function_definition") { - name := getFunctionName(node, content) - if name != "" { - body := getFunctionBody(node, content) - funcs[name] = &FuncInfo{Name: name, Node: node, Body: body} - } - } - - default: - // JS/TS/Go and others - - // Function declarations: function foo() {} - for _, node := range parsed.FindNodesOfType("function_declaration") { - name := getFunctionName(node, content) - if name != "" { - body := getFunctionBody(node, content) - funcs[name] = &FuncInfo{Name: name, Node: node, Body: body} - } - } - - // Arrow functions assigned to variables: const foo = () => {} - for _, node := range parsed.FindNodesOfType("lexical_declaration") { - name, arrowNode := getArrowFunctionName(node, content) - if name != "" && arrowNode != nil { - body := getFunctionBody(arrowNode, content) - funcs[name] = &FuncInfo{Name: name, Node: node, Body: body} - } - } + nodeTypes, ok := functionNodeTypes[l] + if !ok { + // Fallback to JS if language not in map + nodeTypes = functionNodeTypes["js"] + } - // Variable declarations: var foo = function() {} - for _, node := range parsed.FindNodesOfType("variable_declaration") { - name, funcNode := getVariableFunctionName(node, content) - if name != "" && funcNode != nil { - body := getFunctionBody(funcNode, content) - funcs[name] = &FuncInfo{Name: name, Node: node, Body: body} + // Search for all node types for this language + for _, nodeType := range nodeTypes { + for _, node := range parsed.FindNodesOfType(nodeType) { + var name string + var bodyNode *sitter.Node + + // Handle special cases per node type + switch nodeType { + case "lexical_declaration": + // 
JS/TS: const foo = () => {} + name, bodyNode = getArrowFunctionName(node, content) + case "variable_declaration": + // JS/TS: var foo = function() {} + name, bodyNode = getVariableFunctionName(node, content) + case "singleton_method": + // Ruby: def self.foo + name = getFunctionName(node, content) + if name != "" { + name = "self." + name + } + bodyNode = node + case "method_declaration": + // Go: func (T) Method() {} + name = getGoMethodName(node, content) + bodyNode = node + case "function_definition": + // Python: check if inside a class + name = getPythonFunctionName(node, content) + bodyNode = node + case "method": + // Ruby: check if inside a class + name = getRubyMethodName(node, content) + bodyNode = node + case "function_item": + // Rust: check if inside impl block + name = getRustFunctionName(node, content) + bodyNode = node + default: + // Standard function extraction + name = getFunctionName(node, content) + bodyNode = node } - } - // Method definitions in classes/objects - for _, node := range parsed.FindNodesOfType("method_definition") { - name := getFunctionName(node, content) - if name != "" { - body := getFunctionBody(node, content) + if name != "" && bodyNode != nil { + body := getFunctionBody(bodyNode, content) funcs[name] = &FuncInfo{Name: name, Node: node, Body: body} } } @@ -722,6 +741,134 @@ func getFunctionName(node *sitter.Node, content []byte) string { return "" } +// getGoMethodName extracts the qualified name for a Go method (e.g., "User.login") +func getGoMethodName(node *sitter.Node, content []byte) string { + var receiverType string + var methodName string + + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + switch child.Type() { + case "parameter_list": + // First parameter_list is the receiver + if receiverType == "" { + receiverType = getGoReceiverType(child, content) + } + case "field_identifier": + methodName = parse.GetNodeContent(child, content) + } + } + + if methodName == "" { + return "" + } + + if 
receiverType != "" { + return receiverType + "." + methodName + } + return methodName +} + +// getGoReceiverType extracts the type from a Go method receiver parameter list +func getGoReceiverType(paramList *sitter.Node, content []byte) string { + for i := 0; i < int(paramList.ChildCount()); i++ { + child := paramList.Child(i) + if child.Type() == "parameter_declaration" { + for j := 0; j < int(child.ChildCount()); j++ { + typeChild := child.Child(j) + switch typeChild.Type() { + case "type_identifier": + return parse.GetNodeContent(typeChild, content) + case "pointer_type": + // Extract base type from pointer (e.g., "*User" -> "User") + for k := 0; k < int(typeChild.ChildCount()); k++ { + ptrChild := typeChild.Child(k) + if ptrChild.Type() == "type_identifier" { + return parse.GetNodeContent(ptrChild, content) + } + } + } + } + } + } + return "" +} + +// getPythonFunctionName extracts qualified name for Python functions (e.g., "User.login") +func getPythonFunctionName(node *sitter.Node, content []byte) string { + funcName := getFunctionName(node, content) + if funcName == "" { + return "" + } + + // Check if this function is inside a class + parent := node.Parent() + for parent != nil { + if parent.Type() == "class_definition" { + // Found parent class, get its name + for i := 0; i < int(parent.ChildCount()); i++ { + child := parent.Child(i) + if child.Type() == "identifier" { + className := parse.GetNodeContent(child, content) + return className + "." 
+ funcName + } + } + } + parent = parent.Parent() + } + return funcName +} + +// getRubyMethodName extracts qualified name for Ruby methods (e.g., "User#login") +func getRubyMethodName(node *sitter.Node, content []byte) string { + methodName := getFunctionName(node, content) + if methodName == "" { + return "" + } + + // Check if this method is inside a class + parent := node.Parent() + for parent != nil { + if parent.Type() == "class" { + // Found parent class, get its name + for i := 0; i < int(parent.ChildCount()); i++ { + child := parent.Child(i) + if child.Type() == "constant" { + className := parse.GetNodeContent(child, content) + return className + "#" + methodName + } + } + } + parent = parent.Parent() + } + return methodName +} + +// getRustFunctionName extracts qualified name for Rust functions (e.g., "User::login") +func getRustFunctionName(node *sitter.Node, content []byte) string { + funcName := getFunctionName(node, content) + if funcName == "" { + return "" + } + + // Check if this function is inside an impl block + parent := node.Parent() + for parent != nil { + if parent.Type() == "impl_item" { + // Found impl block, get the type name + for i := 0; i < int(parent.ChildCount()); i++ { + child := parent.Child(i) + if child.Type() == "type_identifier" { + typeName := parse.GetNodeContent(child, content) + return typeName + "::" + funcName + } + } + } + parent = parent.Parent() + } + return funcName +} + func getFunctionParams(node *sitter.Node, content []byte) string { for i := 0; i < int(node.ChildCount()); i++ { child := node.Child(i) diff --git a/kai-core/detect/detect_test.go b/kai-core/detect/detect_test.go index 2265d394..95585a3f 100644 --- a/kai-core/detect/detect_test.go +++ b/kai-core/detect/detect_test.go @@ -375,7 +375,7 @@ func TestEqualStringSlices(t *testing.T) { } } -func TestGetAllFunctions(t *testing.T) { +func TestGetAllFunctions_JavaScript(t *testing.T) { parser := parse.NewParser() content := []byte(` function regular() {} @@ 
-402,6 +402,133 @@ class MyClass { } } +func TestGetAllFunctions_TypeScript(t *testing.T) { + parser := parse.NewParser() + content := []byte(`function greet(): void {} + +class User { + login(): void {} +} +`) + + parsed, err := parser.Parse(content, "ts") + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + + funcs := GetAllFunctions(parsed, content, "ts") + + expectedFuncs := []string{"greet", "login"} + for _, expected := range expectedFuncs { + if _, ok := funcs[expected]; !ok { + t.Errorf("expected function %q not found", expected) + } + } +} + +func TestGetAllFunctions_Go(t *testing.T) { + parser := parse.NewParser() + content := []byte(`package main + +func greet() {} + +type User struct{} + +func (u *User) login() {} +`) + + parsed, err := parser.Parse(content, "go") + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + + funcs := GetAllFunctions(parsed, content, "go") + + expectedFuncs := []string{"greet", "User.login"} + for _, expected := range expectedFuncs { + if _, ok := funcs[expected]; !ok { + t.Errorf("expected function %q not found", expected) + } + } +} + +func TestGetAllFunctions_Python(t *testing.T) { + parser := parse.NewParser() + content := []byte(`def greet(): + pass + +class User: + def login(self): + pass +`) + + parsed, err := parser.Parse(content, "py") + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + + funcs := GetAllFunctions(parsed, content, "py") + + expectedFuncs := []string{"greet", "User.login"} + for _, expected := range expectedFuncs { + if _, ok := funcs[expected]; !ok { + t.Errorf("expected function %q not found", expected) + } + } +} + +func TestGetAllFunctions_Ruby(t *testing.T) { + parser := parse.NewParser() + content := []byte(`def greet +end + +class User + def login + end +end +`) + + parsed, err := parser.Parse(content, "rb") + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + + funcs := GetAllFunctions(parsed, content, "rb") + + expectedFuncs := []string{"greet", "User#login"} + for 
_, expected := range expectedFuncs { + if _, ok := funcs[expected]; !ok { + t.Errorf("expected function %q not found", expected) + } + } +} + +func TestGetAllFunctions_Rust(t *testing.T) { + parser := parse.NewParser() + content := []byte(`fn greet() {} + +struct User {} + +impl User { + fn login(&self) {} +} +`) + + parsed, err := parser.Parse(content, "rs") + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + + funcs := GetAllFunctions(parsed, content, "rs") + + expectedFuncs := []string{"greet", "User::login"} + for _, expected := range expectedFuncs { + if _, ok := funcs[expected]; !ok { + t.Errorf("expected function %q not found", expected) + } + } +} + func TestGetArrowFunctionName(t *testing.T) { parser := parse.NewParser() content := []byte(`const myArrow = (x) => x * 2;`)