Allow multicolumn index to specify the order in the annotation
This commit is contained in:
parent
79a77d771d
commit
0fa63da712
40
scope.go
40
scope.go
@ -1278,6 +1278,38 @@ func (scope *Scope) autoMigrate() *Scope {
|
|||||||
return scope
|
return scope
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OrderIndexColumns takes in an index name (name) and list of column names (columns)
|
||||||
|
// and returns a column ordering from the passed in name in the order the column name appears in the index name,
|
||||||
|
// delimited by an underscore (_) .
|
||||||
|
func OrderIndexColumns(name string, columns []string) []string {
|
||||||
|
names := strings.Split(name, "_")
|
||||||
|
numColumns := len(columns)
|
||||||
|
outputColumns := make([]string, numColumns)
|
||||||
|
columnMap := make(map[string]struct{}, numColumns)
|
||||||
|
for _, column := range columns {
|
||||||
|
columnMap[column] = struct{}{}
|
||||||
|
}
|
||||||
|
i := 0
|
||||||
|
for _, name := range names {
|
||||||
|
if _, exists := columnMap[name]; exists {
|
||||||
|
outputColumns[i] = name
|
||||||
|
delete(columnMap, name)
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
if i == numColumns {
|
||||||
|
return outputColumns
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Any remaining columns (improper naming) need to be added
|
||||||
|
for _, column := range columns {
|
||||||
|
if _, exists := columnMap[column]; exists {
|
||||||
|
outputColumns[i] = column
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return outputColumns
|
||||||
|
}
|
||||||
|
|
||||||
func (scope *Scope) autoIndex() *Scope {
|
func (scope *Scope) autoIndex() *Scope {
|
||||||
var indexes = map[string][]string{}
|
var indexes = map[string][]string{}
|
||||||
var uniqueIndexes = map[string][]string{}
|
var uniqueIndexes = map[string][]string{}
|
||||||
@ -1285,7 +1317,6 @@ func (scope *Scope) autoIndex() *Scope {
|
|||||||
for _, field := range scope.GetStructFields() {
|
for _, field := range scope.GetStructFields() {
|
||||||
if name, ok := field.TagSettingsGet("INDEX"); ok {
|
if name, ok := field.TagSettingsGet("INDEX"); ok {
|
||||||
names := strings.Split(name, ",")
|
names := strings.Split(name, ",")
|
||||||
|
|
||||||
for _, name := range names {
|
for _, name := range names {
|
||||||
if name == "INDEX" || name == "" {
|
if name == "INDEX" || name == "" {
|
||||||
name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName)
|
name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName)
|
||||||
@ -1297,7 +1328,6 @@ func (scope *Scope) autoIndex() *Scope {
|
|||||||
|
|
||||||
if name, ok := field.TagSettingsGet("UNIQUE_INDEX"); ok {
|
if name, ok := field.TagSettingsGet("UNIQUE_INDEX"); ok {
|
||||||
names := strings.Split(name, ",")
|
names := strings.Split(name, ",")
|
||||||
|
|
||||||
for _, name := range names {
|
for _, name := range names {
|
||||||
if name == "UNIQUE_INDEX" || name == "" {
|
if name == "UNIQUE_INDEX" || name == "" {
|
||||||
name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName)
|
name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName)
|
||||||
@ -1309,13 +1339,15 @@ func (scope *Scope) autoIndex() *Scope {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for name, columns := range indexes {
|
for name, columns := range indexes {
|
||||||
if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddIndex(name, columns...); db.Error != nil {
|
orderedColumns := OrderIndexColumns(name, columns)
|
||||||
|
if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddIndex(name, orderedColumns...); db.Error != nil {
|
||||||
scope.db.AddError(db.Error)
|
scope.db.AddError(db.Error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, columns := range uniqueIndexes {
|
for name, columns := range uniqueIndexes {
|
||||||
if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil {
|
orderedColumns := OrderIndexColumns(name, columns)
|
||||||
|
if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddUniqueIndex(name, orderedColumns...); db.Error != nil {
|
||||||
scope.db.AddError(db.Error)
|
scope.db.AddError(db.Error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,11 +2,13 @@ package gorm_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/jinzhu/gorm"
|
"github.com/jinzhu/gorm"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NameIn1And2(d *gorm.DB) *gorm.DB {
|
func NameIn1And2(d *gorm.DB) *gorm.DB {
|
||||||
@ -91,3 +93,47 @@ func TestDropTableWithTableOptions(t *testing.T) {
|
|||||||
t.Errorf("Table must be dropped, got error %s", err)
|
t.Errorf("Table must be dropped, got error %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOrderIndexColumns(t *testing.T) {
|
||||||
|
testcases := []struct {
|
||||||
|
indexName string
|
||||||
|
columns []string
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
indexName: "idx_c1_c2_c3",
|
||||||
|
columns: []string{"c1", "c2", "c3"},
|
||||||
|
expected: []string{"c1", "c2", "c3"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
indexName: "c1_c2_c3_idx",
|
||||||
|
columns: []string{"c1", "c2", "c3"},
|
||||||
|
expected: []string{"c1", "c2", "c3"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
indexName: "repositories_c1_c2_c3_idx",
|
||||||
|
columns: []string{"c1", "c2", "c3"},
|
||||||
|
expected: []string{"c1", "c2", "c3"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
indexName: "repositories_c3_c2_c1_idx",
|
||||||
|
columns: []string{"c1", "c2", "c3"},
|
||||||
|
expected: []string{"c3", "c2", "c1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
indexName: "idxc3_c2_c1",
|
||||||
|
columns: []string{"c1", "c2", "c3"},
|
||||||
|
expected: []string{"c2", "c1", "c3"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
indexName: "idx",
|
||||||
|
columns: []string{"c1", "c2", "c3"},
|
||||||
|
expected: []string{"c1", "c2", "c3"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, test := range testcases {
|
||||||
|
t.Run(fmt.Sprintf("index: %s", test.indexName), func(t *testing.T) {
|
||||||
|
assert.Equal(t, test.expected, gorm.OrderIndexColumns(test.indexName, test.columns), "Column output ordering does not match expected result.")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user