Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions ext/store/maxcompute/sanitizer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package maxcompute

import "strings"

// reservedKeywords lists the words that QuoteIdentifier treats as reserved.
// Every entry is written in lowercase because lookups are performed on the
// lowercased identifier; the entries are copied into reservedKeywordsMap for
// membership checks.
//
// The list is mostly, but not strictly, alphabetical ("decimal" and "div" sit
// out of order), so do not rely on it being sorted.
//
// NOTE(review): presumably mirrors the MaxCompute SQL reserved-word list —
// confirm against the vendor documentation before adding or removing entries.
var reservedKeywords = []string{
"add", "after", "all", "alter", "analyze", "and", "archive", "array", "as", "asc",
"before", "between", "bigint", "binary", "blob", "boolean", "both", "decimal",
"bucket", "buckets", "by", "cascade", "case", "cast", "cfile", "change", "cluster",
"clustered", "clusterstatus", "collection", "column", "columns", "comment", "compute",
"concatenate", "continue", "create", "cross", "current", "cursor", "data", "database",
"databases", "date", "datetime", "dbproperties", "deferred", "delete", "delimited",
"desc", "describe", "directory", "disable", "distinct", "distribute", "double", "drop",
"else", "enable", "end", "except", "escaped", "exclusive", "exists", "explain", "export",
"extended", "external", "false", "fetch", "fields", "fileformat", "first", "float",
"following", "format", "formatted", "from", "full", "function", "functions", "grant",
"group", "having", "hold_ddltime", "idxproperties", "if", "import", "in", "index",
"indexes", "inpath", "inputdriver", "inputformat", "insert", "int", "intersect", "into",
"is", "items", "join", "keys", "lateral", "left", "lifecycle", "like", "limit", "lines",
"load", "local", "location", "lock", "locks", "long", "map", "mapjoin", "materialized",
"minus", "msck", "not", "no_drop", "null", "of", "offline", "offset", "on", "option",
"or", "order", "out", "outer", "outputdriver", "outputformat", "over", "overwrite",
"partition", "partitioned", "partitionproperties", "partitions", "percent", "plus",
"preceding", "preserve", "procedure", "purge", "range", "rcfile", "read", "readonly",
"reads", "rebuild", "recordreader", "recordwriter", "reduce", "regexp", "rename",
"repair", "replace", "restrict", "revoke", "right", "rlike", "row", "rows", "schema",
"schemas", "select", "semi", "sequencefile", "serde", "serdeproperties", "set", "shared",
"show", "show_database", "smallint", "sort", "sorted", "ssl", "statistics", "status",
"stored", "streamtable", "string", "struct", "table", "tables", "tablesample",
"tblproperties", "temporary", "terminated", "textfile", "then", "timestamp", "tinyint",
"to", "touch", "transform", "trigger", "true", "type", "unarchive", "unbounded", "undo",
"union", "uniontype", "uniquejoin", "unlock", "unsigned", "update", "use", "using",
"utc", "utc_timestamp", "view", "when", "where", "while", "div",
}

// reservedKeywordsMap is the set form of reservedKeywords: it holds one key
// per entry of that slice and is used for constant-time membership checks.
//
// The set is built by a variable initializer instead of an init function.
// Go orders package-level initializers by dependency, so the set is already
// populated when other package-level variables that read it are initialized,
// and no gochecknoinits suppression is needed.
var reservedKeywordsMap = func() map[string]struct{} {
	set := make(map[string]struct{}, len(reservedKeywords))
	for _, kw := range reservedKeywords {
		set[kw] = struct{}{}
	}
	return set
}()

// QuoteIdentifier returns identifier wrapped in backticks when its lowercased
// form is a key of reservedKeywordsMap. Any other identifier, including the
// empty string, is returned unchanged. The original letter case is preserved
// inside the backticks; only the lookup is case-insensitive.
func QuoteIdentifier(identifier string) string {
	const backtick = "`"

	lowered := strings.ToLower(identifier)
	if _, reserved := reservedKeywordsMap[lowered]; !reserved {
		return identifier
	}
	return backtick + identifier + backtick
}
45 changes: 45 additions & 0 deletions ext/store/maxcompute/sanitizer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package maxcompute_test

import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/goto/optimus/ext/store/maxcompute"
)

// TestSanitizer exercises maxcompute.QuoteIdentifier: identifiers that match a
// reserved keyword in any letter case must come back wrapped in backticks with
// their original casing, and all other identifiers must come back untouched.
func TestSanitizer(t *testing.T) {
	type quoteCase struct {
		input    string
		expected string
	}

	// verify runs every case through QuoteIdentifier and compares the output.
	verify := func(t *testing.T, cases []quoteCase) {
		for i := range cases {
			actual := maxcompute.QuoteIdentifier(cases[i].input)
			assert.Equal(t, cases[i].expected, actual)
		}
	}

	t.Run("returns quoted identifier for reserved keywords", func(t *testing.T) {
		verify(t, []quoteCase{
			{input: "select", expected: "`select`"},
			{input: "from", expected: "`from`"},
			{input: "case", expected: "`case`"},
			{input: "table", expected: "`table`"},
			{input: "SELECT", expected: "`SELECT`"},
			{input: "From", expected: "`From`"},
		})
	})

	t.Run("returns identifier unchanged when not a reserved keyword", func(t *testing.T) {
		verify(t, []quoteCase{
			{input: "customer_name", expected: "customer_name"},
			{input: "other", expected: "other"},
		})
	})
}
14 changes: 7 additions & 7 deletions ext/store/maxcompute/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,14 @@ func populateColumns(t *Table, schemaBuilder *tableschema.SchemaBuilder) error {
func generateUpdateQuery(incoming, existing tableschema.TableSchema, schemaName string) ([]string, error) {
var sqlTasks []string
if incoming.Comment != existing.Comment {
sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s set comment %s;", schemaName, existing.TableName, common.QuoteString(incoming.Comment)))
sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s set comment %s;", QuoteIdentifier(schemaName), QuoteIdentifier(existing.TableName), common.QuoteString(incoming.Comment)))
}

if incoming.Lifecycle != existing.Lifecycle {
if incoming.Lifecycle <= 0 && existing.Lifecycle >= 0 {
sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s disable lifecycle;", schemaName, existing.TableName))
sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s disable lifecycle;", QuoteIdentifier(schemaName), QuoteIdentifier(existing.TableName)))
} else if incoming.Lifecycle > 0 {
sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s set lifecycle %d;", schemaName, existing.TableName, incoming.Lifecycle))
sqlTasks = append(sqlTasks, fmt.Sprintf("alter table %s.%s set lifecycle %d;", QuoteIdentifier(schemaName), QuoteIdentifier(existing.TableName), incoming.Lifecycle))
}
}

Expand Down Expand Up @@ -257,7 +257,7 @@ func getNormalColumnDifferences(tableName, schemaName string, incoming []ColumnR
if incomingColumnRecord.columnValue.NotNull {
return fmt.Errorf("unable to add new required column")
}
segment := fmt.Sprintf("if not exists %s %s", incomingColumnRecord.columnStructure, incomingColumnRecord.columnValue.Type.Name())
segment := fmt.Sprintf("if not exists %s %s", QuoteIdentifier(incomingColumnRecord.columnStructure), incomingColumnRecord.columnValue.Type.Name())
if incomingColumnRecord.columnValue.Comment != "" {
segment += fmt.Sprintf(" comment %s", common.QuoteString(incomingColumnRecord.columnValue.Comment))
}
Expand All @@ -268,7 +268,7 @@ func getNormalColumnDifferences(tableName, schemaName string, incoming []ColumnR
if !columnFound.NotNull && incomingColumnRecord.columnValue.NotNull {
return fmt.Errorf("unable to modify column mode from nullable to required")
} else if columnFound.NotNull && !incomingColumnRecord.columnValue.NotNull {
*sqlTasks = append(*sqlTasks, fmt.Sprintf("alter table %s.%s change column %s null;", schemaName, tableName, columnFound.Name))
*sqlTasks = append(*sqlTasks, fmt.Sprintf("alter table %s.%s change column %s null;", QuoteIdentifier(schemaName), QuoteIdentifier(tableName), QuoteIdentifier(columnFound.Name)))
}

if columnFound.Type.ID() != incomingColumnRecord.columnValue.Type.ID() {
Expand All @@ -277,7 +277,7 @@ func getNormalColumnDifferences(tableName, schemaName string, incoming []ColumnR

if incomingColumnRecord.columnValue.Comment != columnFound.Comment {
*sqlTasks = append(*sqlTasks, fmt.Sprintf("alter table %s.%s change column %s %s %s comment %s;",
schemaName, tableName, columnFound.Name, incomingColumnRecord.columnValue.Name, columnFound.Type, common.QuoteString(incomingColumnRecord.columnValue.Comment)))
QuoteIdentifier(schemaName), QuoteIdentifier(tableName), QuoteIdentifier(columnFound.Name), QuoteIdentifier(incomingColumnRecord.columnValue.Name), columnFound.Type, common.QuoteString(incomingColumnRecord.columnValue.Comment)))
}
delete(existing, incomingColumnRecord.columnStructure)
}
Expand All @@ -290,7 +290,7 @@ func getNormalColumnDifferences(tableName, schemaName string, incoming []ColumnR

if len(columnAddition) > 0 {
for _, segment := range columnAddition {
addColumnQuery := fmt.Sprintf("alter table %s.%s add column ", schemaName, tableName) + segment + ";"
addColumnQuery := fmt.Sprintf("alter table %s.%s add column ", QuoteIdentifier(schemaName), QuoteIdentifier(tableName)) + segment + ";"
*sqlTasks = append(*sqlTasks, addColumnQuery)
}
}
Expand Down
Loading