From 1529430536c23ae311091d7acdfea4609bb33375 Mon Sep 17 00:00:00 2001 From: black Date: Wed, 14 Jun 2023 21:08:11 +0800 Subject: [PATCH] fix MigrateColumnUnique --- migrator/migrator.go | 46 +++++++++++++++++++++++-------------------- schema/field.go | 2 +- schema/index.go | 2 +- schema/index_test.go | 12 +++++------ tests/migrate_test.go | 34 +++++++++++++++++++++++++++----- 5 files changed, 62 insertions(+), 34 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 266cd9b4..f8b975b7 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -549,11 +549,28 @@ func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, co } return m.RunWithValue(value, func(stmt *gorm.Statement) error { + // We're currently only receiving boolean values on `Unique` tag, + // so the UniqueConstraint name is fixed constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName) - index := m.DB.NamingStrategy.IndexName(stmt.Table, field.DBName) - if m.UniqueAffectedByUniqueIndex { if unique { + // Clean up redundant unique indexes + indexes, _ := queryTx.Migrator().GetIndexes(value) + for _, index := range indexes { + if uni, ok := index.Unique(); !ok || !uni { + continue + } + if columns := index.Columns(); len(columns) != 1 || columns[0] != field.DBName { + continue + } + if name := index.Name(); name == constraint || name == field.UniqueIndex { + continue + } + if err := execTx.Migrator().DropIndex(value, index.Name()); err != nil { + return err + } + } + hasConstraint := queryTx.Migrator().HasConstraint(value, constraint) switch { case field.Unique && !hasConstraint: @@ -567,30 +584,17 @@ func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, co if err := execTx.Migrator().DropConstraint(value, constraint); err != nil { return err } - if field.UniqueIndex { - if err := execTx.Migrator().CreateIndex(value, index); err != nil { + if field.UniqueIndex != "" { + if err := execTx.Migrator().CreateIndex(value, field.UniqueIndex); err != nil { return err } } } - hasIndex := queryTx.Migrator().HasIndex(value, index) - switch { - case field.UniqueIndex && !hasIndex: - if err := execTx.Migrator().CreateIndex(value, index); err != nil { + if field.UniqueIndex != "" && !queryTx.Migrator().HasIndex(value, field.UniqueIndex) { + if err := execTx.Migrator().CreateIndex(value, field.UniqueIndex); err != nil { return err } - // field isn't UniqueIndex but ColumnType's Unique is reported by UniqueIndex - case !field.UniqueIndex && hasIndex: - if err := execTx.Migrator().DropIndex(value, index); err != nil { - return err - } - // create normal index - if idx := stmt.Schema.LookIndex(index); idx != nil { - if err := execTx.Migrator().CreateIndex(value, index); err != nil { - return err - } - } } } else { if field.Unique { @@ -598,8 +602,8 @@ func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, co return err } } - if field.UniqueIndex { - if err := execTx.Migrator().CreateIndex(value, index); err != nil { + if field.UniqueIndex != "" { + if err := execTx.Migrator().CreateIndex(value, field.UniqueIndex); err != nil { return err } } diff --git a/schema/field.go b/schema/field.go index c25a7ec0..8df568ab 100644 --- a/schema/field.go +++ b/schema/field.go @@ -69,7 +69,7 @@ type Field struct { DefaultValueInterface interface{} NotNull bool Unique bool - UniqueIndex bool + UniqueIndex string Comment string Size int Precision int diff --git a/schema/index.go b/schema/index.go index a99bfde4..f4f36751 100644 --- a/schema/index.go +++ b/schema/index.go @@ -67,7 +67,7 @@ func (schema *Schema) ParseIndexes() map[string]Index { } for _, index := range indexes { if index.Class == "UNIQUE" && len(index.Fields) == 1 { - index.Fields[0].Field.UniqueIndex = true + index.Fields[0].Field.UniqueIndex = index.Name } } return indexes diff --git a/schema/index_test.go b/schema/index_test.go index 3983a7bc..2f1e36af 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -66,7 +66,7 @@ func TestParseIndex(t *testing.T) { "idx_name": { Name: "idx_name", Class: "UNIQUE", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", UniqueIndex: true}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", UniqueIndex: "idx_name"}}}, }, "idx_user_indices_name3": { Name: "idx_user_indices_name3", @@ -82,7 +82,7 @@ func TestParseIndex(t *testing.T) { "idx_user_indices_name4": { Name: "idx_user_indices_name4", Class: "UNIQUE", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", UniqueIndex: true}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", UniqueIndex: "idx_user_indices_name4"}}}, }, "idx_user_indices_name5": { Name: "idx_user_indices_name5", @@ -103,12 +103,12 @@ func TestParseIndex(t *testing.T) { }, "idx_id": { Name: "idx_id", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", UniqueIndex: true}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}}, }, "idx_oid": { Name: "idx_oid", Class: "UNIQUE", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", UniqueIndex: true}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}}, }, "type": { Name: "type", @@ -191,7 +191,7 @@ func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) { "idx_index_tests_field_c": { Name: "idx_index_tests_field_c", Class: "UNIQUE", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldC", UniqueIndex: true}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldC", UniqueIndex: "idx_index_tests_field_c"}}}, }, "idx_index_tests_field_d": { Name: "idx_index_tests_field_d", @@ -225,7 +225,7 @@ func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) { "idx_index_tests_field_g": { Name: "idx_index_tests_field_g", Class: "UNIQUE", - Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldG", Unique: true, UniqueIndex: true}}}, + Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldG", Unique: true, UniqueIndex: "idx_index_tests_field_g"}}}, }, "uniq_field_h1_h2": { Name: "uniq_field_h1_h2", diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 2fcdcd29..33a6d9c4 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -15,6 +15,7 @@ import ( "gorm.io/driver/postgres" "gorm.io/gorm" + "gorm.io/gorm/clause" "gorm.io/gorm/logger" "gorm.io/gorm/migrator" "gorm.io/gorm/schema" @@ -925,7 +926,8 @@ func TestCurrentTimestamp(t *testing.T) { if err != nil { t.Fatalf("AutoMigrate err:%v", err) } - AssertEqual(t, true, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at")) + AssertEqual(t, true, DB.Migrator().HasConstraint(&CurrentTimestampTest{}, "uni_current_timestamp_tests_time_at")) + AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at")) AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at_2")) } @@ -987,7 +989,8 @@ func TestUniqueColumn(t *testing.T) { } // not trigger alert column - AssertEqual(t, true, DB.Migrator().HasIndex(&UniqueTest{}, "name")) + AssertEqual(t, true, DB.Migrator().HasConstraint(&UniqueTest{}, "uni_unique_tests_name")) + AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name")) AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_1")) AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_2")) @@ -1648,13 +1651,12 @@ func TestMigrateWithUniqueIndexAndUnique(t *testing.T) { if field.Unique != unique { t.Fatalf("%v: %q column %q unique should be %v but got %v", utils.FileWithLineNum(), stmt.Schema.Table, fieldName, unique, field.Unique) } - if field.UniqueIndex != uniqueIndex { + if (field.UniqueIndex != "") != uniqueIndex { t.Fatalf("%v: %q column %q uniqueIndex should be %v but got %v", utils.FileWithLineNum(), stmt.Schema, fieldName, uniqueIndex, field.UniqueIndex) } } - // not unique - type ( + type ( // not unique UniqueStruct1 struct { Name string `gorm:"size:10"` } @@ -1774,4 +1776,26 @@ func TestMigrateWithUniqueIndexAndUnique(t *testing.T) { test.checkFunc(t) }) } + + if DB.Dialector.Name() == "mysql" { + compatibilityTests := []TestCase{ + {name: "oldUnique to notUnique", to: UniqueStruct1{}, checkFunc: checkNotUnique}, + {name: "oldUnique to unique", to: UniqueStruct3{}, checkFunc: checkUnique}, + {name: "oldUnique to uniqueIndex", to: UniqueStruct5{}, checkFunc: checkUniqueIndex}, + } + for _, test := range compatibilityTests { + t.Run(test.name, func(t *testing.T) { + if err := DB.Migrator().DropTable(table); err != nil { + t.Fatalf("failed to drop table, got error: %v", err) + } + if err := DB.Exec("CREATE TABLE ? (`name` varchar(10) UNIQUE)", clause.Table{Name: table}).Error; err != nil { + t.Fatalf("failed to create table, got error: %v", err) + } + if err := DB.Debug().Table(table).AutoMigrate(test.to); err != nil { + t.Fatalf("failed to migrate table, got error: %v", err) + } + test.checkFunc(t) + }) + } + } }