fix: migrator run with nil schema

This commit is contained in:
black 2023-05-10 09:51:43 +08:00
parent a1b7e47e75
commit a0c4a21718
2 changed files with 74 additions and 44 deletions

View File

@ -55,8 +55,6 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error
if table, ok := value.(string); ok { if table, ok := value.(string); ok {
stmt.Table = table stmt.Table = table
// set schema to avoid panic
stmt.Schema = &schema.Schema{}
} else if err := stmt.ParseWithSpecialTableName(value, stmt.Table); err != nil { } else if err := stmt.ParseWithSpecialTableName(value, stmt.Table); err != nil {
return err return err
} }
@ -347,7 +345,10 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
func (m Migrator) AddColumn(value interface{}, name string) error { func (m Migrator) AddColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
// avoid using the same name field // avoid using the same name field
f := stmt.Schema.LookUpField(name) var f *schema.Field
if stmt.Schema != nil {
f = stmt.Schema.LookUpField(name)
}
if f == nil { if f == nil {
return fmt.Errorf("failed to look up field with name: %s", name) return fmt.Errorf("failed to look up field with name: %s", name)
} }
@ -366,8 +367,10 @@ func (m Migrator) AddColumn(value interface{}, name string) error {
// DropColumn drop value's `name` column // DropColumn drop value's `name` column
func (m Migrator) DropColumn(value interface{}, name string) error { func (m Migrator) DropColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(name); field != nil { if stmt.Schema != nil {
name = field.DBName if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}
} }
return m.DB.Exec( return m.DB.Exec(
@ -379,13 +382,14 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
// AlterColumn alter value's `field` column' type based on schema definition // AlterColumn alter value's `field` column' type based on schema definition
func (m Migrator) AlterColumn(value interface{}, field string) error { func (m Migrator) AlterColumn(value interface{}, field string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(field); field != nil { if stmt.Schema != nil {
fileType := m.FullDataTypeOf(field) if field := stmt.Schema.LookUpField(field); field != nil {
return m.DB.Exec( fileType := m.FullDataTypeOf(field)
"ALTER TABLE ? ALTER COLUMN ? TYPE ?", return m.DB.Exec(
m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, "ALTER TABLE ? ALTER COLUMN ? TYPE ?",
).Error m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType,
).Error
}
} }
return fmt.Errorf("failed to look up field with name: %s", field) return fmt.Errorf("failed to look up field with name: %s", field)
}) })
@ -397,8 +401,10 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase() currentDatabase := m.DB.Migrator().CurrentDatabase()
name := field name := field
if field := stmt.Schema.LookUpField(field); field != nil { if stmt.Schema != nil {
name = field.DBName if field := stmt.Schema.LookUpField(field); field != nil {
name = field.DBName
}
} }
return m.DB.Raw( return m.DB.Raw(
@ -413,12 +419,14 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
// RenameColumn rename value's field name from oldName to newName // RenameColumn rename value's field name from oldName to newName
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if field := stmt.Schema.LookUpField(oldName); field != nil { if stmt.Schema != nil {
oldName = field.DBName if field := stmt.Schema.LookUpField(oldName); field != nil {
} oldName = field.DBName
}
if field := stmt.Schema.LookUpField(newName); field != nil { if field := stmt.Schema.LookUpField(newName); field != nil {
newName = field.DBName newName = field.DBName
}
} }
return m.DB.Exec( return m.DB.Exec(
@ -756,29 +764,31 @@ type BuildIndexOptionsInterface interface {
// CreateIndex create index `name` // CreateIndex create index `name`
func (m Migrator) CreateIndex(value interface{}, name string) error { func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil { if stmt.Schema != nil {
opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) if idx := stmt.Schema.LookIndex(name); idx != nil {
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts}
createIndexSQL := "CREATE " createIndexSQL := "CREATE "
if idx.Class != "" { if idx.Class != "" {
createIndexSQL += idx.Class + " " createIndexSQL += idx.Class + " "
}
createIndexSQL += "INDEX ? ON ??"
if idx.Type != "" {
createIndexSQL += " USING " + idx.Type
}
if idx.Comment != "" {
createIndexSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment)
}
if idx.Option != "" {
createIndexSQL += " " + idx.Option
}
return m.DB.Exec(createIndexSQL, values...).Error
} }
createIndexSQL += "INDEX ? ON ??"
if idx.Type != "" {
createIndexSQL += " USING " + idx.Type
}
if idx.Comment != "" {
createIndexSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment)
}
if idx.Option != "" {
createIndexSQL += " " + idx.Option
}
return m.DB.Exec(createIndexSQL, values...).Error
} }
return fmt.Errorf("failed to create index with name %s", name) return fmt.Errorf("failed to create index with name %s", name)
@ -788,8 +798,10 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
// DropIndex drop index `name` // DropIndex drop index `name`
func (m Migrator) DropIndex(value interface{}, name string) error { func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if idx := stmt.Schema.LookIndex(name); idx != nil { if stmt.Schema != nil {
name = idx.Name if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
} }
return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error return m.DB.Exec("DROP INDEX ? ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error
@ -801,8 +813,10 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int64 var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error { m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase() currentDatabase := m.DB.Migrator().CurrentDatabase()
if idx := stmt.Schema.LookIndex(name); idx != nil { if stmt.Schema != nil {
name = idx.Name if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
} }
return m.DB.Raw( return m.DB.Raw(

View File

@ -384,6 +384,22 @@ func TestMigrateIndexes(t *testing.T) {
if DB.Migrator().HasIndex(&IndexStruct{}, "idx_users_name_1") { if DB.Migrator().HasIndex(&IndexStruct{}, "idx_users_name_1") {
t.Fatalf("Should not find index for user's name after delete") t.Fatalf("Should not find index for user's name after delete")
} }
if err := DB.Migrator().CreateIndex(&IndexStruct{}, "Name"); err != nil {
t.Fatalf("Got error when tried to create index: %+v", err)
}
if err := DB.Migrator().RenameIndex("index_structs", "idx_index_structs_name", "idx_users_name_1"); err != nil {
t.Fatalf("no error should happen when rename index, but got %v", err)
}
if !DB.Migrator().HasIndex("index_structs", "idx_users_name_1") {
t.Fatalf("Should find index for user's name after rename")
}
if err := DB.Migrator().DropIndex("index_structs", "idx_users_name_1"); err != nil {
t.Fatalf("Failed to drop index for user's name, got err %v", err)
}
} }
func TestTiDBMigrateColumns(t *testing.T) { func TestTiDBMigrateColumns(t *testing.T) {