diff --git a/migrator/migrator.go b/migrator/migrator.go index 4d04d3fe..266cd9b4 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -26,6 +26,8 @@ type Migrator struct { // Config schema config type Config struct { CreateIndexAfterCreateTable bool + // Unique in ColumnType is affected by UniqueIndex, e.g. MySQL + UniqueAffectedByUniqueIndex bool DB *gorm.DB gorm.Dialector } @@ -115,9 +117,8 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { return err } var ( - parseIndexes = stmt.Schema.ParseIndexes() - parseCheckConstraints = stmt.Schema.ParseCheckConstraints() - parseUniqueConstraints = stmt.Schema.ParseUniqueConstraints() + parseIndexes = stmt.Schema.ParseIndexes() + parseCheckConstraints = stmt.Schema.ParseCheckConstraints() ) for _, dbName := range stmt.Schema.DBNames { var foundColumn gorm.ColumnType @@ -157,14 +158,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } } - for _, uni := range parseUniqueConstraints { - if !queryTx.Migrator().HasConstraint(value, uni.Name) { - if err := execTx.Migrator().CreateConstraint(value, uni.Name); err != nil { - return err - } - } - } - for _, chk := range parseCheckConstraints { if !queryTx.Migrator().HasConstraint(value, chk.Name) { if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil { @@ -438,6 +431,10 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error // MigrateColumn migrate column func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { + if field.IgnoreMigration { + return nil + } + // found, smart migrate fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) realDataType := strings.ToLower(columnType.DatabaseTypeName()) @@ -497,14 +494,6 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } - // check unique - if unique, ok := columnType.Unique(); ok && !unique && (field.Unique || field.UniqueIndex) { - // not primary key - if !field.PrimaryKey { - alterColumn = true - } - } - // check default value if !field.PrimaryKey { currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL")) @@ -533,13 +522,101 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } - if alterColumn && !field.IgnoreMigration { - return m.DB.Migrator().AlterColumn(value, field.DBName) + if alterColumn { + if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil { + return err + } + } + + if err := m.MigrateColumnUnique(value, field, columnType); err != nil { + return err } return nil } +func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { + unique, ok := columnType.Unique() + if !ok || field.PrimaryKey { + return nil // skip primary key + } + + queryTx := m.DB.Session(&gorm.Session{}) + execTx := queryTx + if m.DB.DryRun { + queryTx.DryRun = false + execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}}) + } + + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName) + index := m.DB.NamingStrategy.IndexName(stmt.Table, field.DBName) + + if m.UniqueAffectedByUniqueIndex { + if unique { + hasConstraint := queryTx.Migrator().HasConstraint(value, constraint) + switch { + case field.Unique && !hasConstraint: + if field.Unique { + if err := execTx.Migrator().CreateConstraint(value, constraint); err != nil { + return err + } + } + // field isn't Unique but ColumnType's Unique is reported by UniqueConstraint. + case !field.Unique && hasConstraint: + if err := execTx.Migrator().DropConstraint(value, constraint); err != nil { + return err + } + if field.UniqueIndex { + if err := execTx.Migrator().CreateIndex(value, index); err != nil { + return err + } + } + } + + hasIndex := queryTx.Migrator().HasIndex(value, index) + switch { + case field.UniqueIndex && !hasIndex: + if err := execTx.Migrator().CreateIndex(value, index); err != nil { + return err + } + // field isn't UniqueIndex but ColumnType's Unique is reported by UniqueIndex + case !field.UniqueIndex && hasIndex: + if err := execTx.Migrator().DropIndex(value, index); err != nil { + return err + } + // create normal index + if idx := stmt.Schema.LookIndex(index); idx != nil { + if err := execTx.Migrator().CreateIndex(value, index); err != nil { + return err + } + } + } + } else { + if field.Unique { + if err := execTx.Migrator().CreateConstraint(value, constraint); err != nil { + return err + } + } + if field.UniqueIndex { + if err := execTx.Migrator().CreateIndex(value, index); err != nil { + return err + } + } + return nil + } + } else { + if unique && !field.Unique { + return execTx.Migrator().DropConstraint(value, constraint) + } + if !unique && field.Unique { + return execTx.Migrator().CreateConstraint(value, constraint) + } + } + return nil + }) +} + // ColumnTypes return columnTypes []gorm.ColumnType and execErr error func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { columnTypes := make([]gorm.ColumnType, 0) @@ -610,7 +687,22 @@ func (m Migrator) DropView(name string) error { } // GuessConstraintAndTable guess statement's constraint and it's table based on name -func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) { +// +// Deprecated: use GuessConstraintInterfaceAndTable instead. +func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (*schema.Constraint, *schema.CheckConstraint, string) { + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) + switch c := constraint.(type) { + case *schema.Constraint: + return c, nil, table + case *schema.CheckConstraint: + return nil, c, table + default: + return nil, nil, table + } +} + +// GuessConstraintInterfaceAndTable guess statement's constraint and it's table based on name +func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) { if stmt.Schema == nil { return nil, stmt.Table } @@ -669,7 +761,7 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ // CreateConstraint create constraint func (m Migrator) CreateConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - constraint, table := m.GuessConstraintAndTable(stmt, name) + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { vars := []interface{}{clause.Table{Name: table}} if stmt.TableExpr != nil { @@ -685,7 +777,7 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { // DropConstraint drop constraint func (m Migrator) DropConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - constraint, table := m.GuessConstraintAndTable(stmt, name) + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { name = constraint.GetName() } @@ -698,7 +790,7 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() - constraint, table := m.GuessConstraintAndTable(stmt, name) + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { name = constraint.GetName() } diff --git a/schema/relationship.go b/schema/relationship.go index 28d7e288..2cce3d41 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -566,7 +566,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu } } -type FKConstraint struct { +// Constraint is ForeignKey Constraint +type Constraint struct { Name string Field *Field Schema *Schema @@ -577,9 +578,9 @@ type FKConstraint struct { OnUpdate string } -func (constraint *FKConstraint) GetName() string { return constraint.Name } +func (constraint *Constraint) GetName() string { return constraint.Name } -func (constraint *FKConstraint) Build() (sql string, vars []interface{}) { +func (constraint *Constraint) Build() (sql string, vars []interface{}) { sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" if constraint.OnDelete != "" { sql += " ON DELETE " + constraint.OnDelete @@ -601,7 +602,7 @@ func (constraint *FKConstraint) Build() (sql string, vars []interface{}) { return } -func (rel *Relationship) ParseConstraint() *FKConstraint { +func (rel *Relationship) ParseConstraint() *Constraint { str := rel.Field.TagSettings["CONSTRAINT"] if str == "-" { return nil @@ -641,7 +642,7 @@ func (rel *Relationship) ParseConstraint() *FKConstraint { name = rel.Schema.namer.RelationshipFKName(*rel) } - constraint := FKConstraint{ + constraint := Constraint{ Name: name, Field: rel.Field, OnUpdate: settings["ONUPDATE"], diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 56b17601..2fcdcd29 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1654,40 +1654,46 @@ func TestMigrateWithUniqueIndexAndUnique(t *testing.T) { } // not unique - type UniqueStruct1 struct { - Name string `gorm:"size:20"` - } + type ( + UniqueStruct1 struct { + Name string `gorm:"size:10"` + } + UniqueStruct2 struct { + Name string `gorm:"size:20"` + } + ) checkField(&UniqueStruct1{}, "name", false, false) + checkField(&UniqueStruct2{}, "name", false, false) type ( // unique - UniqueStruct2 struct { - Name string `gorm:"size:20;unique"` - } UniqueStruct3 struct { Name string `gorm:"size:30;unique"` } + UniqueStruct4 struct { + Name string `gorm:"size:40;unique"` + } ) - checkField(&UniqueStruct2{}, "name", true, false) checkField(&UniqueStruct3{}, "name", true, false) + checkField(&UniqueStruct4{}, "name", true, false) type ( // uniqueIndex - UniqueStruct4 struct { - Name string `gorm:"size:40;uniqueIndex"` - } UniqueStruct5 struct { Name string `gorm:"size:50;uniqueIndex"` } UniqueStruct6 struct { - Name string `gorm:"size:20;uniqueIndex:idx_us6_all_names"` - NickName string `gorm:"size:20;uniqueIndex:idx_us6_all_names"` + Name string `gorm:"size:60;uniqueIndex"` + } + UniqueStruct7 struct { + Name string `gorm:"size:70;uniqueIndex:idx_us6_all_names"` + NickName string `gorm:"size:70;uniqueIndex:idx_us6_all_names"` } ) - checkField(&UniqueStruct4{}, "name", false, true) checkField(&UniqueStruct5{}, "name", false, true) + checkField(&UniqueStruct6{}, "name", false, true) - checkField(&UniqueStruct6{}, "name", false, false) - checkField(&UniqueStruct6{}, "nick_name", false, false) - checkField(&UniqueStruct6{}, "nick_name", false, false) + checkField(&UniqueStruct7{}, "name", false, false) + checkField(&UniqueStruct7{}, "nick_name", false, false) + checkField(&UniqueStruct7{}, "nick_name", false, false) type TestCase struct { name string @@ -1703,36 +1709,46 @@ func TestMigrateWithUniqueIndexAndUnique(t *testing.T) { checkColumnType(t, "name", true) checkIndex(t, nil) } - index := &migrator.Index{ + uniqueIndex := &migrator.Index{ TableName: table, NameValue: DB.Config.NamingStrategy.IndexName(table, "name"), ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, } - if DB.Dialector.Name() != "mysql" { + checkUniqueIndex := func(t *testing.T) { + checkColumnType(t, "name", false) + checkIndex(t, []gorm.Index{uniqueIndex}) + } + if DB.Dialector.Name() == "mysql" { // in mysql, unique equals uniqueIndex checkUnique = func(t *testing.T) { checkColumnType(t, "name", true) - checkIndex(t, []gorm.Index{index}) + checkIndex(t, []gorm.Index{&migrator.Index{ + TableName: table, + NameValue: DB.Config.NamingStrategy.UniqueName(table, "name"), + ColumnList: []string{"name"}, + PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, + UniqueValue: sql.NullBool{Bool: true, Valid: true}, + }}) + } + checkUniqueIndex = func(t *testing.T) { + checkColumnType(t, "name", true) + checkIndex(t, []gorm.Index{uniqueIndex}) } - } - checkUniqueIndex := func(t *testing.T) { - checkColumnType(t, "name", false) - checkIndex(t, []gorm.Index{index}) } tests := []TestCase{ - {name: "notUnique to notUnique", from: &UniqueStruct1{}, to: &UniqueStruct1{}, checkFunc: checkNotUnique}, - {name: "notUnique to unique", from: &UniqueStruct1{}, to: &UniqueStruct2{}, checkFunc: checkUnique}, - {name: "notUnique to uniqueIndex", from: &UniqueStruct1{}, to: &UniqueStruct4{}, checkFunc: checkUniqueIndex}, - {name: "unique to notUnique", from: &UniqueStruct2{}, to: &UniqueStruct1{}, checkFunc: checkNotUnique}, - {name: "unique to unique", from: &UniqueStruct2{}, to: &UniqueStruct3{}, checkFunc: checkUnique}, - {name: "unique to uniqueIndex", from: &UniqueStruct2{}, to: &UniqueStruct4{}, checkFunc: checkUniqueIndex}, - {name: "uniqueIndex to notUnique", from: &UniqueStruct4{}, to: &UniqueStruct2{}, checkFunc: checkNotUnique}, - {name: "uniqueIndex to unique", from: &UniqueStruct4{}, to: &UniqueStruct2{}, checkFunc: checkUnique}, - {name: "uniqueIndex to uniqueIndex", from: &UniqueStruct4{}, to: &UniqueStruct5{}, checkFunc: checkUniqueIndex}, - {name: "uniqueIndex to multi uniqueIndex", from: &UniqueStruct4{}, to: &UniqueStruct6{}, checkFunc: func(t *testing.T) { + {name: "notUnique to notUnique", from: &UniqueStruct1{}, to: &UniqueStruct2{}, checkFunc: checkNotUnique}, + {name: "notUnique to unique", from: &UniqueStruct1{}, to: &UniqueStruct3{}, checkFunc: checkUnique}, + {name: "notUnique to uniqueIndex", from: &UniqueStruct1{}, to: &UniqueStruct5{}, checkFunc: checkUniqueIndex}, + {name: "unique to notUnique", from: &UniqueStruct3{}, to: &UniqueStruct1{}, checkFunc: checkNotUnique}, + {name: "unique to unique", from: &UniqueStruct3{}, to: &UniqueStruct4{}, checkFunc: checkUnique}, + {name: "unique to uniqueIndex", from: &UniqueStruct3{}, to: &UniqueStruct5{}, checkFunc: checkUniqueIndex}, + {name: "uniqueIndex to notUnique", from: &UniqueStruct5{}, to: &UniqueStruct2{}, checkFunc: checkNotUnique}, + {name: "uniqueIndex to unique", from: &UniqueStruct5{}, to: &UniqueStruct3{}, checkFunc: checkUnique}, + {name: "uniqueIndex to uniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct6{}, checkFunc: checkUniqueIndex}, + {name: "uniqueIndex to multi uniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct7{}, checkFunc: func(t *testing.T) { checkColumnType(t, "name", false) checkColumnType(t, "nick_name", false) checkIndex(t, []gorm.Index{&migrator.Index{