diff --git a/migrator/migrator.go b/migrator/migrator.go index 39d8426f..7461441d 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -55,8 +55,6 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error if table, ok := value.(string); ok { stmt.Table = table - // set schema to avoid panic - stmt.Schema = &schema.Schema{} } else if err := stmt.ParseWithSpecialTableName(value, stmt.Table); err != nil { return err } @@ -347,7 +345,10 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error { func (m Migrator) AddColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { // avoid using the same name field - f := stmt.Schema.LookUpField(name) + var f *schema.Field + if stmt.Schema != nil { + f = stmt.Schema.LookUpField(name) + } if f == nil { return fmt.Errorf("failed to look up field with name: %s", name) } @@ -366,8 +367,10 @@ func (m Migrator) AddColumn(value interface{}, name string) error { // DropColumn drop value's `name` column func (m Migrator) DropColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(name); field != nil { - name = field.DBName + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(name); field != nil { + name = field.DBName + } } return m.DB.Exec( @@ -379,13 +382,14 @@ func (m Migrator) DropColumn(value interface{}, name string) error { // AlterColumn alter value's `field` column' type based on schema definition func (m Migrator) AlterColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(field); field != nil { - fileType := m.FullDataTypeOf(field) - return m.DB.Exec( - "ALTER TABLE ? ALTER COLUMN ? TYPE ?", - m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, - ).Error - + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(field); field != nil { + fileType := m.FullDataTypeOf(field) + return m.DB.Exec( + "ALTER TABLE ? ALTER COLUMN ? TYPE ?", + m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, + ).Error + } } return fmt.Errorf("failed to look up field with name: %s", field) }) @@ -397,8 +401,10 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() name := field - if field := stmt.Schema.LookUpField(field); field != nil { - name = field.DBName + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(field); field != nil { + name = field.DBName + } } return m.DB.Raw( @@ -413,12 +419,14 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { // RenameColumn rename value's field name from oldName to newName func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if field := stmt.Schema.LookUpField(oldName); field != nil { - oldName = field.DBName - } + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(oldName); field != nil { + oldName = field.DBName + } - if field := stmt.Schema.LookUpField(newName); field != nil { - newName = field.DBName + if field := stmt.Schema.LookUpField(newName); field != nil { + newName = field.DBName + } } return m.DB.Exec( @@ -756,29 +764,31 @@ type BuildIndexOptionsInterface interface { // CreateIndex create index `name` func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) - values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} + if stmt.Schema != nil { + if idx := stmt.Schema.LookIndex(name); idx != nil { + opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) + values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} - createIndexSQL := "CREATE " - if idx.Class != "" { - createIndexSQL += idx.Class + " " + createIndexSQL := "CREATE " + if idx.Class != "" { + createIndexSQL += idx.Class + " " + } + createIndexSQL += "INDEX ? ON ??" + + if idx.Type != "" { + createIndexSQL += " USING " + idx.Type + } + + if idx.Comment != "" { + createIndexSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment) + } + + if idx.Option != "" { + createIndexSQL += " " + idx.Option + } + + return m.DB.Exec(createIndexSQL, values...).Error } - createIndexSQL += "INDEX ? ON ??" - - if idx.Type != "" { - createIndexSQL += " USING " + idx.Type - } - - if idx.Comment != "" { - createIndexSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment) - } - - if idx.Option != "" { - createIndexSQL += " " + idx.Option - } - - return m.DB.Exec(createIndexSQL, values...).Error } return fmt.Errorf("failed to create index with name %s", name) @@ -788,8 +798,10 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { // DropIndex drop index `name` func (m Migrator) DropIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { - if idx := stmt.Schema.LookIndex(name); idx != nil { - name = idx.Name + if stmt.Schema != nil { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } } return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error @@ -801,8 +813,10 @@ func (m Migrator) HasIndex(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() - if idx := stmt.Schema.LookIndex(name); idx != nil { - name = idx.Name + if stmt.Schema != nil { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } } return m.DB.Raw( diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 0d617566..7ad22de6 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -384,6 +384,22 @@ func TestMigrateIndexes(t *testing.T) { if DB.Migrator().HasIndex(&IndexStruct{}, "idx_users_name_1") { t.Fatalf("Should not find index for user's name after delete") } + + if err := DB.Migrator().CreateIndex(&IndexStruct{}, "Name"); err != nil { + t.Fatalf("Got error when tried to create index: %+v", err) + } + + if err := DB.Migrator().RenameIndex("index_structs", "idx_index_structs_name", "idx_users_name_1"); err != nil { + t.Fatalf("no error should happen when rename index, but got %v", err) + } + + if !DB.Migrator().HasIndex("index_structs", "idx_users_name_1") { + t.Fatalf("Should find index for user's name after rename") + } + + if err := DB.Migrator().DropIndex("index_structs", "idx_users_name_1"); err != nil { + t.Fatalf("Failed to drop index for user's name, got err %v", err) + } } func TestTiDBMigrateColumns(t *testing.T) {