optimize if logic and migrator.go ColumnTypes func.
This commit is contained in:
parent
3226937f68
commit
fae2dd4815
@ -2,9 +2,11 @@ package migrator
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@ -155,7 +157,8 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m Migrator) CreateTable(values ...interface{}) error {
|
func (m Migrator) CreateTable(values ...interface{}) error {
|
||||||
for _, value := range m.ReorderModels(values, false) {
|
results := m.ReorderModels(values, false)
|
||||||
|
for _, value := range results {
|
||||||
tx := m.DB.Session(&gorm.Session{})
|
tx := m.DB.Session(&gorm.Session{})
|
||||||
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) {
|
||||||
var (
|
var (
|
||||||
@ -206,25 +209,20 @@ func (m Migrator) CreateTable(values ...interface{}) error {
|
|||||||
|
|
||||||
for _, rel := range stmt.Schema.Relationships.Relations {
|
for _, rel := range stmt.Schema.Relationships.Relations {
|
||||||
if !m.DB.DisableForeignKeyConstraintWhenMigrating {
|
if !m.DB.DisableForeignKeyConstraintWhenMigrating {
|
||||||
if constraint := rel.ParseConstraint(); constraint != nil {
|
if constraint := rel.ParseConstraint(); constraint != nil && constraint.Schema == stmt.Schema {
|
||||||
if constraint.Schema == stmt.Schema {
|
|
||||||
sql, vars := buildConstraint(constraint)
|
sql, vars := buildConstraint(constraint)
|
||||||
createTableSQL += sql + ","
|
createTableSQL += sql + ","
|
||||||
values = append(values, vars...)
|
values = append(values, vars...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
for _, chk := range stmt.Schema.ParseCheckConstraints() {
|
for _, chk := range stmt.Schema.ParseCheckConstraints() {
|
||||||
createTableSQL += "CONSTRAINT ? CHECK (?),"
|
createTableSQL += "CONSTRAINT ? CHECK (?),"
|
||||||
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
|
values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint})
|
||||||
}
|
}
|
||||||
|
|
||||||
createTableSQL = strings.TrimSuffix(createTableSQL, ",")
|
createTableSQL = strings.TrimSuffix(createTableSQL, ",") + ")"
|
||||||
|
|
||||||
createTableSQL += ")"
|
|
||||||
|
|
||||||
if tableOption, ok := m.DB.Get("gorm:table_options"); ok {
|
if tableOption, ok := m.DB.Get("gorm:table_options"); ok {
|
||||||
createTableSQL += fmt.Sprint(tableOption)
|
createTableSQL += fmt.Sprint(tableOption)
|
||||||
}
|
}
|
||||||
@ -382,30 +380,27 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
|||||||
alterColumn = true
|
alterColumn = true
|
||||||
} else {
|
} else {
|
||||||
// has size in data type and not equal
|
// has size in data type and not equal
|
||||||
|
|
||||||
// Since the following code is frequently called in the for loop, reg optimization is needed here
|
// Since the following code is frequently called in the for loop, reg optimization is needed here
|
||||||
matches := regRealDataType.FindAllStringSubmatch(realDataType, -1)
|
matches := regRealDataType.FindAllStringSubmatch(realDataType, -1)
|
||||||
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
|
matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1)
|
||||||
if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) {
|
if (len(matches) == 1 && matches[0][1] != strconv.Itoa(field.Size) || !field.PrimaryKey) &&
|
||||||
|
(len(matches2) == 1 && matches2[0][1] != strconv.FormatInt(length, 10)) {
|
||||||
alterColumn = true
|
alterColumn = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// check precision
|
// check precision
|
||||||
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision {
|
if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision &&
|
||||||
if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
|
regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) {
|
||||||
alterColumn = true
|
alterColumn = true
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// check nullable
|
// check nullable
|
||||||
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull {
|
if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull && !field.PrimaryKey && nullable {
|
||||||
// not primary key & database is nullable
|
// not primary key & database is nullable
|
||||||
if !field.PrimaryKey && nullable {
|
|
||||||
alterColumn = true
|
alterColumn = true
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if alterColumn && !field.IgnoreMigration {
|
if alterColumn && !field.IgnoreMigration {
|
||||||
return m.DB.Migrator().AlterColumn(value, field.Name)
|
return m.DB.Migrator().AlterColumn(value, field.Name)
|
||||||
@ -414,22 +409,30 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) {
|
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||||
columnTypes = make([]gorm.ColumnType, 0)
|
columnTypes := make([]gorm.ColumnType, 0)
|
||||||
err = m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||||
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
|
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
|
||||||
if err == nil {
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
rawColumnTypes, err := rows.ColumnTypes()
|
|
||||||
if err == nil {
|
var rawColumnTypes []*sql.ColumnType
|
||||||
|
rawColumnTypes, err = rows.ColumnTypes()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
for _, c := range rawColumnTypes {
|
for _, c := range rawColumnTypes {
|
||||||
columnTypes = append(columnTypes, c)
|
columnTypes = append(columnTypes, c)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
return nil
|
||||||
return err
|
|
||||||
})
|
})
|
||||||
return
|
|
||||||
|
return columnTypes, execErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
|
func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
|
||||||
@ -608,7 +611,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
|
|||||||
return m.DB.Exec(createIndexSQL, values...).Error
|
return m.DB.Exec(createIndexSQL, values...).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("failed to create index with name %v", name)
|
return fmt.Errorf("failed to create index with name %s", name)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user