optimize if logic and migrator.go ColumnTypes func.

This commit is contained in:
daheige 2021-06-14 10:20:44 +08:00
parent 3226937f68
commit fae2dd4815

View File

@ -2,9 +2,11 @@ package migrator
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"reflect" "reflect"
"regexp" "regexp"
"strconv"
"strings" "strings"
"gorm.io/gorm" "gorm.io/gorm"
@ -155,7 +157,8 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
} }
func (m Migrator) CreateTable(values ...interface{}) error { func (m Migrator) CreateTable(values ...interface{}) error {
for _, value := range m.ReorderModels(values, false) { results := m.ReorderModels(values, false)
for _, value := range results {
tx := m.DB.Session(&gorm.Session{}) tx := m.DB.Session(&gorm.Session{})
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
var ( var (
@ -206,25 +209,20 @@ func (m Migrator) CreateTable(values ...interface{}) error {
for _, rel := range stmt.Schema.Relationships.Relations { for _, rel := range stmt.Schema.Relationships.Relations {
if !m.DB.DisableForeignKeyConstraintWhenMigrating { if !m.DB.DisableForeignKeyConstraintWhenMigrating {
if constraint := rel.ParseConstraint(); constraint != nil { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Schema == stmt.Schema {
if constraint.Schema == stmt.Schema {
sql, vars := buildConstraint(constraint) sql, vars := buildConstraint(constraint)
createTableSQL += sql + "," createTableSQL += sql + ","
values = append(values, vars...) values = append(values, vars...)
} }
} }
} }
}
for _, chk := range stmt.Schema.ParseCheckConstraints() { for _, chk := range stmt.Schema.ParseCheckConstraints() {
createTableSQL += "CONSTRAINT ? CHECK (?)," createTableSQL += "CONSTRAINT ? CHECK (?),"
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}) values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
} }
createTableSQL = strings.TrimSuffix(createTableSQL, ",") createTableSQL = strings.TrimSuffix(createTableSQL, ",") + ")"
createTableSQL += ")"
if tableOption, ok := m.DB.Get("gorm:table_options"); ok { if tableOption, ok := m.DB.Get("gorm:table_options"); ok {
createTableSQL += fmt.Sprint(tableOption) createTableSQL += fmt.Sprint(tableOption)
} }
@ -382,30 +380,27 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
alterColumn = true alterColumn = true
} else { } else {
// has size in data type and not equal // has size in data type and not equal
// Since the following code is frequently called in the for loop, reg optimization is needed here // Since the following code is frequently called in the for loop, reg optimization is needed here
matches := regRealDataType.FindAllStringSubmatch(realDataType, -1) matches := regRealDataType.FindAllStringSubmatch(realDataType, -1)
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) { if (len(matches) == 1 && matches[0][1] != strconv.Itoa(field.Size) || !field.PrimaryKey) &&
(len(matches2) == 1 && matches2[0][1] != strconv.FormatInt(length, 10)) {
alterColumn = true alterColumn = true
} }
} }
} }
// check precision // check precision
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision &&
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
alterColumn = true alterColumn = true
} }
}
// check nullable // check nullable
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull { if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull && !field.PrimaryKey && nullable {
// not primary key & database is nullable // not primary key & database is nullable
if !field.PrimaryKey && nullable {
alterColumn = true alterColumn = true
} }
}
if alterColumn && !field.IgnoreMigration { if alterColumn && !field.IgnoreMigration {
return m.DB.Migrator().AlterColumn(value, field.Name) return m.DB.Migrator().AlterColumn(value, field.Name)
@ -414,22 +409,30 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
return nil return nil
} }
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) { func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes = make([]gorm.ColumnType, 0) columnTypes := make([]gorm.ColumnType, 0)
err = m.RunWithValue(value, func(stmt *gorm.Statement) error { execErr := m.RunWithValue(value, func(stmt *gorm.Statement) error {
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
if err == nil { if err != nil {
return err
}
defer rows.Close() defer rows.Close()
rawColumnTypes, err := rows.ColumnTypes()
if err == nil { var rawColumnTypes []*sql.ColumnType
rawColumnTypes, err = rows.ColumnTypes()
if err != nil {
return err
}
for _, c := range rawColumnTypes { for _, c := range rawColumnTypes {
columnTypes = append(columnTypes, c) columnTypes = append(columnTypes, c)
} }
}
} return nil
return err
}) })
return
return columnTypes, execErr
} }
func (m Migrator) CreateView(name string, option gorm.ViewOption) error { func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
@ -608,7 +611,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.DB.Exec(createIndexSQL, values...).Error return m.DB.Exec(createIndexSQL, values...).Error
} }
return fmt.Errorf("failed to create index with name %v", name) return fmt.Errorf("failed to create index with name %s", name)
}) })
} }