diff --git a/kai-core/detect/detect.go b/kai-core/detect/detect.go index e7c1b48b..c824bf0d 100644 --- a/kai-core/detect/detect.go +++ b/kai-core/detect/detect.go @@ -11,6 +11,37 @@ import ( "kai-core/parse" ) +// functionNodeTypes maps each language to the AST node types that represent +// function-like declarations (functions, methods, etc.) +var functionNodeTypes = map[string][]string{ + "js": { + "function_declaration", // function foo() {} + "method_definition", // class methods + "lexical_declaration", // const foo = () => {} + "variable_declaration", // var foo = function() {} + }, + "ts": { + "function_declaration", + "method_definition", + "lexical_declaration", + "variable_declaration", + }, + "py": { + "function_definition", // Both standalone functions and methods + }, + "go": { + "function_declaration", // func Foo() {} + "method_declaration", // func (T) Method() {} + }, + "rb": { + "method", // def foo + "singleton_method", // def self.foo + }, + "rs": { + "function_item", + }, +} + // ChangeCategory represents a type of change. type ChangeCategory string @@ -56,12 +87,12 @@ const ( DependencyUpdated ChangeCategory = "DEPENDENCY_UPDATED" // Semantic config changes - FeatureFlagChanged ChangeCategory = "FEATURE_FLAG_CHANGED" - TimeoutChanged ChangeCategory = "TIMEOUT_CHANGED" - LimitChanged ChangeCategory = "LIMIT_CHANGED" - RetryConfigChanged ChangeCategory = "RETRY_CONFIG_CHANGED" - EndpointChanged ChangeCategory = "ENDPOINT_CHANGED" - CredentialChanged ChangeCategory = "CREDENTIAL_CHANGED" + FeatureFlagChanged ChangeCategory = "FEATURE_FLAG_CHANGED" + TimeoutChanged ChangeCategory = "TIMEOUT_CHANGED" + LimitChanged ChangeCategory = "LIMIT_CHANGED" + RetryConfigChanged ChangeCategory = "RETRY_CONFIG_CHANGED" + EndpointChanged ChangeCategory = "ENDPOINT_CHANGED" + CredentialChanged ChangeCategory = "CREDENTIAL_CHANGED" // Schema/migration changes SchemaFieldAdded ChangeCategory = "SCHEMA_FIELD_ADDED" @@ -243,69 +274,57 @@ func GetAllFunctions(parsed *parse.ParsedFile, content []byte, lang ...string) m l = lang[0] } - switch l { - case "rb": - // Ruby: method and singleton_method nodes - for _, node := range parsed.FindNodesOfType("method") { - name := getFunctionName(node, content) - if name != "" { - body := getFunctionBody(node, content) - funcs[name] = &FuncInfo{Name: name, Node: node, Body: body} - } - } - for _, node := range parsed.FindNodesOfType("singleton_method") { - name := getFunctionName(node, content) - if name != "" { - body := getFunctionBody(node, content) - funcs["self."+name] = &FuncInfo{Name: "self." + name, Node: node, Body: body} - } - } - - case "py": - // Python: function_definition nodes - for _, node := range parsed.FindNodesOfType("function_definition") { - name := getFunctionName(node, content) - if name != "" { - body := getFunctionBody(node, content) - funcs[name] = &FuncInfo{Name: name, Node: node, Body: body} - } - } - - default: - // JS/TS/Go and others - - // Function declarations: function foo() {} - for _, node := range parsed.FindNodesOfType("function_declaration") { - name := getFunctionName(node, content) - if name != "" { - body := getFunctionBody(node, content) - funcs[name] = &FuncInfo{Name: name, Node: node, Body: body} - } - } - - // Arrow functions assigned to variables: const foo = () => {} - for _, node := range parsed.FindNodesOfType("lexical_declaration") { - name, arrowNode := getArrowFunctionName(node, content) - if name != "" && arrowNode != nil { - body := getFunctionBody(arrowNode, content) - funcs[name] = &FuncInfo{Name: name, Node: node, Body: body} - } - } + nodeTypes, ok := functionNodeTypes[l] + if !ok { + // Fallback to JS if language not in map + nodeTypes = functionNodeTypes["js"] + } - // Variable declarations: var foo = function() {} - for _, node := range parsed.FindNodesOfType("variable_declaration") { - name, funcNode := getVariableFunctionName(node, content) - if name != "" && funcNode != nil { - body := getFunctionBody(funcNode, content) - funcs[name] = &FuncInfo{Name: name, Node: node, Body: body} + // Search for all node types for this language + for _, nodeType := range nodeTypes { + for _, node := range parsed.FindNodesOfType(nodeType) { + var name string + var bodyNode *sitter.Node + + // Handle special cases per node type + switch nodeType { + case "lexical_declaration": + // JS/TS: const foo = () => {} + name, bodyNode = getArrowFunctionName(node, content) + case "variable_declaration": + // JS/TS: var foo = function() {} + name, bodyNode = getVariableFunctionName(node, content) + case "singleton_method": + // Ruby: def self.foo + name = getFunctionName(node, content) + if name != "" { + name = "self." + name + } + bodyNode = node + case "method_declaration": + // Go: func (T) Method() {} + name = getGoMethodName(node, content) + bodyNode = node + case "function_definition": + // Python: check if inside a class + name = getPythonFunctionName(node, content) + bodyNode = node + case "method": + // Ruby: check if inside a class + name = getRubyMethodName(node, content) + bodyNode = node + case "function_item": + // Rust: check if inside impl block + name = getRustFunctionName(node, content) + bodyNode = node + default: + // Standard function extraction + name = getFunctionName(node, content) + bodyNode = node } - } - // Method definitions in classes/objects - for _, node := range parsed.FindNodesOfType("method_definition") { - name := getFunctionName(node, content) - if name != "" { - body := getFunctionBody(node, content) + if name != "" && bodyNode != nil { + body := getFunctionBody(bodyNode, content) funcs[name] = &FuncInfo{Name: name, Node: node, Body: body} } } @@ -722,6 +741,134 @@ func getFunctionName(node *sitter.Node, content []byte) string { return "" } +// getGoMethodName extracts the qualified name for a Go method (e.g., "User.login") +func getGoMethodName(node *sitter.Node, content []byte) string { + var receiverType string + var methodName string + + for i := 0; i < int(node.ChildCount()); i++ { + child := node.Child(i) + switch child.Type() { + case "parameter_list": + // First parameter_list is the receiver + if receiverType == "" { + receiverType = getGoReceiverType(child, content) + } + case "field_identifier": + methodName = parse.GetNodeContent(child, content) + } + } + + if methodName == "" { + return "" + } + + if receiverType != "" { + return receiverType + "." + methodName + } + return methodName +} + +// getGoReceiverType extracts the type from a Go method receiver parameter list +func getGoReceiverType(paramList *sitter.Node, content []byte) string { + for i := 0; i < int(paramList.ChildCount()); i++ { + child := paramList.Child(i) + if child.Type() == "parameter_declaration" { + for j := 0; j < int(child.ChildCount()); j++ { + typeChild := child.Child(j) + switch typeChild.Type() { + case "type_identifier": + return parse.GetNodeContent(typeChild, content) + case "pointer_type": + // Extract base type from pointer (e.g., "*User" -> "User") + for k := 0; k < int(typeChild.ChildCount()); k++ { + ptrChild := typeChild.Child(k) + if ptrChild.Type() == "type_identifier" { + return parse.GetNodeContent(ptrChild, content) + } + } + } + } + } + } + return "" +} + +// getPythonFunctionName extracts qualified name for Python functions (e.g., "User.login") +func getPythonFunctionName(node *sitter.Node, content []byte) string { + funcName := getFunctionName(node, content) + if funcName == "" { + return "" + } + + // Check if this function is inside a class + parent := node.Parent() + for parent != nil { + if parent.Type() == "class_definition" { + // Found parent class, get its name + for i := 0; i < int(parent.ChildCount()); i++ { + child := parent.Child(i) + if child.Type() == "identifier" { + className := parse.GetNodeContent(child, content) + return className + "." + funcName + } + } + } + parent = parent.Parent() + } + return funcName +} + +// getRubyMethodName extracts qualified name for Ruby methods (e.g., "User#login") +func getRubyMethodName(node *sitter.Node, content []byte) string { + methodName := getFunctionName(node, content) + if methodName == "" { + return "" + } + + // Check if this method is inside a class + parent := node.Parent() + for parent != nil { + if parent.Type() == "class" { + // Found parent class, get its name + for i := 0; i < int(parent.ChildCount()); i++ { + child := parent.Child(i) + if child.Type() == "constant" { + className := parse.GetNodeContent(child, content) + return className + "#" + methodName + } + } + } + parent = parent.Parent() + } + return methodName +} + +// getRustFunctionName extracts qualified name for Rust functions (e.g., "User::login") +func getRustFunctionName(node *sitter.Node, content []byte) string { + funcName := getFunctionName(node, content) + if funcName == "" { + return "" + } + + // Check if this function is inside an impl block + parent := node.Parent() + for parent != nil { + if parent.Type() == "impl_item" { + // Found impl block, get the type name + for i := 0; i < int(parent.ChildCount()); i++ { + child := parent.Child(i) + if child.Type() == "type_identifier" { + typeName := parse.GetNodeContent(child, content) + return typeName + "::" + funcName + } + } + } + parent = parent.Parent() + } + return funcName +} + func getFunctionParams(node *sitter.Node, content []byte) string { for i := 0; i < int(node.ChildCount()); i++ { child := node.Child(i) diff --git a/kai-core/detect/detect_test.go b/kai-core/detect/detect_test.go index 2265d394..95585a3f 100644 --- a/kai-core/detect/detect_test.go +++ b/kai-core/detect/detect_test.go @@ -375,7 +375,7 @@ func TestEqualStringSlices(t *testing.T) { } } -func TestGetAllFunctions(t *testing.T) { +func TestGetAllFunctions_JavaScript(t *testing.T) { parser := parse.NewParser() content := []byte(` function regular() {} @@ -402,6 +402,133 @@ class MyClass { } } +func TestGetAllFunctions_TypeScript(t *testing.T) { + parser := parse.NewParser() + content := []byte(`function greet(): void {} + +class User { + login(): void {} +} +`) + + parsed, err := parser.Parse(content, "ts") + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + + funcs := GetAllFunctions(parsed, content, "ts") + + expectedFuncs := []string{"greet", "login"} + for _, expected := range expectedFuncs { + if _, ok := funcs[expected]; !ok { + t.Errorf("expected function %q not found", expected) + } + } +} + +func TestGetAllFunctions_Go(t *testing.T) { + parser := parse.NewParser() + content := []byte(`package main + +func greet() {} + +type User struct{} + +func (u *User) login() {} +`) + + parsed, err := parser.Parse(content, "go") + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + + funcs := GetAllFunctions(parsed, content, "go") + + expectedFuncs := []string{"greet", "User.login"} + for _, expected := range expectedFuncs { + if _, ok := funcs[expected]; !ok { + t.Errorf("expected function %q not found", expected) + } + } +} + +func TestGetAllFunctions_Python(t *testing.T) { + parser := parse.NewParser() + content := []byte(`def greet(): + pass + +class User: + def login(self): + pass +`) + + parsed, err := parser.Parse(content, "py") + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + + funcs := GetAllFunctions(parsed, content, "py") + + expectedFuncs := []string{"greet", "User.login"} + for _, expected := range expectedFuncs { + if _, ok := funcs[expected]; !ok { + t.Errorf("expected function %q not found", expected) + } + } +} + +func TestGetAllFunctions_Ruby(t *testing.T) { + parser := parse.NewParser() + content := []byte(`def greet +end + +class User + def login + end +end +`) + + parsed, err := parser.Parse(content, "rb") + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + + funcs := GetAllFunctions(parsed, content, "rb") + + expectedFuncs := []string{"greet", "User#login"} + for _, expected := range expectedFuncs { + if _, ok := funcs[expected]; !ok { + t.Errorf("expected function %q not found", expected) + } + } +} + +func TestGetAllFunctions_Rust(t *testing.T) { + parser := parse.NewParser() + content := []byte(`fn greet() {} + +struct User {} + +impl User { + fn login(&self) {} +} +`) + + parsed, err := parser.Parse(content, "rs") + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + + funcs := GetAllFunctions(parsed, content, "rs") + + expectedFuncs := []string{"greet", "User::login"} + for _, expected := range expectedFuncs { + if _, ok := funcs[expected]; !ok { + t.Errorf("expected function %q not found", expected) + } + } +} + func TestGetArrowFunctionName(t *testing.T) { parser := parse.NewParser() content := []byte(`const myArrow = (x) => x * 2;`)