diff --git a/scope.go b/scope.go index d82cadbc..c037aff8 100644 --- a/scope.go +++ b/scope.go @@ -1278,6 +1278,38 @@ func (scope *Scope) autoMigrate() *Scope { return scope } +// OrderIndexColumns takes in an index name (name) and list of column names (columns) +// and returns a column ordering from the passed in name in the order the column name appears in the index name, +// delimited by an underscore (_) . +func OrderIndexColumns(name string, columns []string) []string { + names := strings.Split(name, "_") + numColumns := len(columns) + outputColumns := make([]string, numColumns) + columnMap := make(map[string]struct{}, numColumns) + for _, column := range columns { + columnMap[column] = struct{}{} + } + i := 0 + for _, name := range names { + if _, exists := columnMap[name]; exists { + outputColumns[i] = name + delete(columnMap, name) + i++ + } + if i == numColumns { + return outputColumns + } + } + // Any remaining columns (improper naming) need to be added + for _, column := range columns { + if _, exists := columnMap[column]; exists { + outputColumns[i] = column + i++ + } + } + return outputColumns +} + func (scope *Scope) autoIndex() *Scope { var indexes = map[string][]string{} var uniqueIndexes = map[string][]string{} @@ -1285,7 +1317,6 @@ func (scope *Scope) autoIndex() *Scope { for _, field := range scope.GetStructFields() { if name, ok := field.TagSettingsGet("INDEX"); ok { names := strings.Split(name, ",") - for _, name := range names { if name == "INDEX" || name == "" { name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName) @@ -1297,7 +1328,6 @@ func (scope *Scope) autoIndex() *Scope { if name, ok := field.TagSettingsGet("UNIQUE_INDEX"); ok { names := strings.Split(name, ",") - for _, name := range names { if name == "UNIQUE_INDEX" || name == "" { name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName) @@ -1309,13 +1339,15 @@ func (scope *Scope) autoIndex() *Scope { } for name, columns := range indexes { - if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddIndex(name, columns...); db.Error != nil { + orderedColumns := OrderIndexColumns(name, columns) + if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddIndex(name, orderedColumns...); db.Error != nil { scope.db.AddError(db.Error) } } for name, columns := range uniqueIndexes { - if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil { + orderedColumns := OrderIndexColumns(name, columns) + if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddUniqueIndex(name, orderedColumns...); db.Error != nil { scope.db.AddError(db.Error) } } diff --git a/scope_test.go b/scope_test.go index f7f1ed08..e98985cf 100644 --- a/scope_test.go +++ b/scope_test.go @@ -2,11 +2,13 @@ package gorm_test import ( "encoding/hex" + "fmt" "math/rand" "strings" "testing" "github.com/jinzhu/gorm" + "github.com/stretchr/testify/assert" ) func NameIn1And2(d *gorm.DB) *gorm.DB { @@ -91,3 +93,47 @@ func TestDropTableWithTableOptions(t *testing.T) { t.Errorf("Table must be dropped, got error %s", err) } } + +func TestOrderIndexColumns(t *testing.T) { + testcases := []struct { + indexName string + columns []string + expected []string + }{ + { + indexName: "idx_c1_c2_c3", + columns: []string{"c1", "c2", "c3"}, + expected: []string{"c1", "c2", "c3"}, + }, + { + indexName: "c1_c2_c3_idx", + columns: []string{"c1", "c2", "c3"}, + expected: []string{"c1", "c2", "c3"}, + }, + { + indexName: "repositories_c1_c2_c3_idx", + columns: []string{"c1", "c2", "c3"}, + expected: []string{"c1", "c2", "c3"}, + }, + { + indexName: "repositories_c3_c2_c1_idx", + columns: []string{"c1", "c2", "c3"}, + expected: []string{"c3", "c2", "c1"}, + }, + { + indexName: "idxc3_c2_c1", + columns: []string{"c1", "c2", "c3"}, + expected: []string{"c2", "c1", "c3"}, + }, + { + indexName: "idx", + columns: []string{"c1", "c2", "c3"}, + expected: []string{"c1", "c2", "c3"}, + }, + } + for _, test := range testcases { + t.Run(fmt.Sprintf("index: %s", test.indexName), func(t *testing.T) { + assert.Equal(t, test.expected, gorm.OrderIndexColumns(test.indexName, test.columns), "Column output ordering does not match expected result.") + }) + } +}