diff --git a/migrator/migrator.go b/migrator/migrator.go index 03ffdd02..da9ddd49 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -2,9 +2,11 @@ package migrator import ( "context" + "database/sql" "fmt" "reflect" "regexp" + "strconv" "strings" "gorm.io/gorm" @@ -155,7 +157,8 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { } func (m Migrator) CreateTable(values ...interface{}) error { - for _, value := range m.ReorderModels(values, false) { + results := m.ReorderModels(values, false) + for _, value := range results { tx := m.DB.Session(&gorm.Session{}) if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { var ( @@ -206,12 +209,10 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, rel := range stmt.Schema.Relationships.Relations { if !m.DB.DisableForeignKeyConstraintWhenMigrating { - if constraint := rel.ParseConstraint(); constraint != nil { - if constraint.Schema == stmt.Schema { - sql, vars := buildConstraint(constraint) - createTableSQL += sql + "," - values = append(values, vars...) - } + if constraint := rel.ParseConstraint(); constraint != nil && constraint.Schema == stmt.Schema { + sql, vars := buildConstraint(constraint) + createTableSQL += sql + "," + values = append(values, vars...) } } } @@ -221,10 +222,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}) } - createTableSQL = strings.TrimSuffix(createTableSQL, ",") - - createTableSQL += ")" - + createTableSQL = strings.TrimSuffix(createTableSQL, ",") + ")" if tableOption, ok := m.DB.Get("gorm:table_options"); ok { createTableSQL += fmt.Sprint(tableOption) } @@ -382,29 +380,26 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy alterColumn = true } else { // has size in data type and not equal - // Since the following code is frequently called in the for loop, reg optimization is needed here matches := regRealDataType.FindAllStringSubmatch(realDataType, -1) matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) - if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) { + if (len(matches) == 1 && matches[0][1] != strconv.Itoa(field.Size) || !field.PrimaryKey) && + (len(matches2) == 1 && matches2[0][1] != strconv.FormatInt(length, 10)) { alterColumn = true } } } // check precision - if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { - if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { - alterColumn = true - } + if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision && + regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { + alterColumn = true } // check nullable - if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull { + if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull && !field.PrimaryKey && nullable { // not primary key & database is nullable - if !field.PrimaryKey && nullable { - alterColumn = true - } + alterColumn = true } if alterColumn && !field.IgnoreMigration { @@ -414,22 +409,30 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy return nil } -func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) { - columnTypes = make([]gorm.ColumnType, 0) - err = m.RunWithValue(value, func(stmt *gorm.Statement) error { +func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { + columnTypes := make([]gorm.ColumnType, 0) + execErr := m.RunWithValue(value, func(stmt *gorm.Statement) error { rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() - if err == nil { - defer rows.Close() - rawColumnTypes, err := rows.ColumnTypes() - if err == nil { - for _, c := range rawColumnTypes { - columnTypes = append(columnTypes, c) - } - } + if err != nil { + return err } - return err + + defer rows.Close() + + var rawColumnTypes []*sql.ColumnType + rawColumnTypes, err = rows.ColumnTypes() + if err != nil { + return err + } + + for _, c := range rawColumnTypes { + columnTypes = append(columnTypes, c) + } + + return nil }) - return + + return columnTypes, execErr } func (m Migrator) CreateView(name string, option gorm.ViewOption) error { @@ -608,7 +611,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { return m.DB.Exec(createIndexSQL, values...).Error } - return fmt.Errorf("failed to create index with name %v", name) + return fmt.Errorf("failed to create index with name %s", name) }) }