Support smart migrate, close #3078
This commit is contained in:
		
							parent
							
								
									3a97639880
								
							
						
					
					
						commit
						cc6a64adfb
					
				| @ -42,6 +42,7 @@ type Migrator interface { | |||||||
| 	AddColumn(dst interface{}, field string) error | 	AddColumn(dst interface{}, field string) error | ||||||
| 	DropColumn(dst interface{}, field string) error | 	DropColumn(dst interface{}, field string) error | ||||||
| 	AlterColumn(dst interface{}, field string) error | 	AlterColumn(dst interface{}, field string) error | ||||||
|  | 	MigrateColumn(dst interface{}, field *schema.Field, columnType *sql.ColumnType) error | ||||||
| 	HasColumn(dst interface{}, field string) bool | 	HasColumn(dst interface{}, field string) bool | ||||||
| 	RenameColumn(dst interface{}, oldName, field string) error | 	RenameColumn(dst interface{}, oldName, field string) error | ||||||
| 	ColumnTypes(dst interface{}) ([]*sql.ColumnType, error) | 	ColumnTypes(dst interface{}) ([]*sql.ColumnType, error) | ||||||
|  | |||||||
| @ -5,6 +5,7 @@ import ( | |||||||
| 	"database/sql" | 	"database/sql" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"reflect" | 	"reflect" | ||||||
|  | 	"regexp" | ||||||
| 	"strings" | 	"strings" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm" | 	"gorm.io/gorm" | ||||||
| @ -80,7 +81,6 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { | |||||||
| 
 | 
 | ||||||
| // AutoMigrate
 | // AutoMigrate
 | ||||||
| func (m Migrator) AutoMigrate(values ...interface{}) error { | func (m Migrator) AutoMigrate(values ...interface{}) error { | ||||||
| 	// TODO smart migrate data type
 |  | ||||||
| 	for _, value := range m.ReorderModels(values, true) { | 	for _, value := range m.ReorderModels(values, true) { | ||||||
| 		tx := m.DB.Session(&gorm.Session{}) | 		tx := m.DB.Session(&gorm.Session{}) | ||||||
| 		if !tx.Migrator().HasTable(value) { | 		if !tx.Migrator().HasTable(value) { | ||||||
| @ -89,11 +89,26 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { | |||||||
| 			} | 			} | ||||||
| 		} else { | 		} else { | ||||||
| 			if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { | 			if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { | ||||||
|  | 				columnTypes, _ := m.DB.Migrator().ColumnTypes(value) | ||||||
|  | 
 | ||||||
| 				for _, field := range stmt.Schema.FieldsByDBName { | 				for _, field := range stmt.Schema.FieldsByDBName { | ||||||
| 					if !tx.Migrator().HasColumn(value, field.DBName) { | 					var foundColumn *sql.ColumnType | ||||||
|  | 
 | ||||||
|  | 					for _, columnType := range columnTypes { | ||||||
|  | 						if columnType.Name() == field.DBName { | ||||||
|  | 							foundColumn = columnType | ||||||
|  | 							break | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 
 | ||||||
|  | 					if foundColumn == nil { | ||||||
|  | 						// not found, add column
 | ||||||
| 						if err := tx.Migrator().AddColumn(value, field.DBName); err != nil { | 						if err := tx.Migrator().AddColumn(value, field.DBName); err != nil { | ||||||
| 							return err | 							return err | ||||||
| 						} | 						} | ||||||
|  | 					} else if err := m.DB.Migrator().MigrateColumn(value, field, foundColumn); err != nil { | ||||||
|  | 						// found, smart migrate
 | ||||||
|  | 						return err | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| @ -120,7 +135,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { | |||||||
| 				} | 				} | ||||||
| 				return nil | 				return nil | ||||||
| 			}); err != nil { | 			}); err != nil { | ||||||
| 				fmt.Println(err) |  | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| @ -327,6 +341,49 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType *sql.ColumnType) error { | ||||||
|  | 	// found, smart migrate
 | ||||||
|  | 	fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) | ||||||
|  | 	realDataType := strings.ToLower(columnType.DatabaseTypeName()) | ||||||
|  | 
 | ||||||
|  | 	alterColumn := false | ||||||
|  | 
 | ||||||
|  | 	// check size
 | ||||||
|  | 	if length, _ := columnType.Length(); length != int64(field.Size) { | ||||||
|  | 		if length > 0 && field.Size > 0 { | ||||||
|  | 			alterColumn = true | ||||||
|  | 		} else { | ||||||
|  | 			// has size in data type and not equal
 | ||||||
|  | 			matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllString(realDataType, 1) | ||||||
|  | 			matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1) | ||||||
|  | 			if len(matches) > 0 && matches[1] != fmt.Sprint(field.Size) || len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) { | ||||||
|  | 				alterColumn = true | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// check precision
 | ||||||
|  | 	if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { | ||||||
|  | 		if strings.Contains(fullDataType, fmt.Sprint(field.Precision)) { | ||||||
|  | 			alterColumn = true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// check nullable
 | ||||||
|  | 	if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull { | ||||||
|  | 		// not primary key & database is nullable
 | ||||||
|  | 		if !field.PrimaryKey && nullable { | ||||||
|  | 			alterColumn = true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if alterColumn { | ||||||
|  | 		return m.DB.Migrator().AlterColumn(value, field.Name) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { | func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { | ||||||
| 	err = m.RunWithValue(value, func(stmt *gorm.Statement) error { | 	err = m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows() | 		rows, err := m.DB.Raw("select * from ?", clause.Table{Name: stmt.Table}).Rows() | ||||||
|  | |||||||
| @ -55,6 +55,7 @@ type Field struct { | |||||||
| 	Comment               string | 	Comment               string | ||||||
| 	Size                  int | 	Size                  int | ||||||
| 	Precision             int | 	Precision             int | ||||||
|  | 	Scale                 int | ||||||
| 	FieldType             reflect.Type | 	FieldType             reflect.Type | ||||||
| 	IndirectFieldType     reflect.Type | 	IndirectFieldType     reflect.Type | ||||||
| 	StructField           reflect.StructField | 	StructField           reflect.StructField | ||||||
| @ -160,6 +161,10 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | |||||||
| 		field.Precision, _ = strconv.Atoi(p) | 		field.Precision, _ = strconv.Atoi(p) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	if s, ok := field.TagSettings["SCALE"]; ok { | ||||||
|  | 		field.Scale, _ = strconv.Atoi(s) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) { | 	if val, ok := field.TagSettings["NOT NULL"]; ok && utils.CheckTruth(val) { | ||||||
| 		field.NotNull = true | 		field.NotNull = true | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -379,7 +379,6 @@ func (stmt *Statement) Build(clauses ...string) { | |||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	// TODO handle named vars
 |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (stmt *Statement) Parse(value interface{}) (err error) { | func (stmt *Statement) Parse(value interface{}) (err error) { | ||||||
|  | |||||||
| @ -6,11 +6,11 @@ require ( | |||||||
| 	github.com/google/uuid v1.1.1 | 	github.com/google/uuid v1.1.1 | ||||||
| 	github.com/jinzhu/now v1.1.1 | 	github.com/jinzhu/now v1.1.1 | ||||||
| 	github.com/lib/pq v1.6.0 | 	github.com/lib/pq v1.6.0 | ||||||
| 	gorm.io/driver/mysql v0.3.1 | 	gorm.io/driver/mysql v0.3.2 | ||||||
| 	gorm.io/driver/postgres v0.2.6 | 	gorm.io/driver/postgres v0.2.9 | ||||||
| 	gorm.io/driver/sqlite v1.0.9 | 	gorm.io/driver/sqlite v1.0.9 | ||||||
| 	gorm.io/driver/sqlserver v0.2.7 | 	gorm.io/driver/sqlserver v0.2.7 | ||||||
| 	gorm.io/gorm v0.2.19 | 	gorm.io/gorm v0.2.36 | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| replace gorm.io/gorm => ../ | replace gorm.io/gorm => ../ | ||||||
|  | |||||||
| @ -47,6 +47,86 @@ func TestMigrate(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func TestSmartMigrateColumn(t *testing.T) { | ||||||
|  | 	type UserMigrateColumn struct { | ||||||
|  | 		ID       uint | ||||||
|  | 		Name     string | ||||||
|  | 		Salary   float64 | ||||||
|  | 		Birthday time.Time | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	DB.Migrator().DropTable(&UserMigrateColumn{}) | ||||||
|  | 
 | ||||||
|  | 	DB.AutoMigrate(&UserMigrateColumn{}) | ||||||
|  | 
 | ||||||
|  | 	type UserMigrateColumn2 struct { | ||||||
|  | 		ID       uint | ||||||
|  | 		Name     string    `gorm:"size:128"` | ||||||
|  | 		Salary   float64   `gorm:"precision:2"` | ||||||
|  | 		Birthday time.Time `gorm:"precision:2"` | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil { | ||||||
|  | 		t.Fatalf("failed to auto migrate, got error: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("failed to get column types, got error: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, columnType := range columnTypes { | ||||||
|  | 		switch columnType.Name() { | ||||||
|  | 		case "name": | ||||||
|  | 			if length, _ := columnType.Length(); length != 0 && length != 128 { | ||||||
|  | 				t.Fatalf("name's length should be 128, but got %v", length) | ||||||
|  | 			} | ||||||
|  | 		case "salary": | ||||||
|  | 			if precision, o, _ := columnType.DecimalSize(); precision != 0 && precision != 2 { | ||||||
|  | 				t.Fatalf("salary's precision should be 2, but got %v %v", precision, o) | ||||||
|  | 			} | ||||||
|  | 		case "birthday": | ||||||
|  | 			if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 2 { | ||||||
|  | 				t.Fatalf("birthday's precision should be 2, but got %v", precision) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	type UserMigrateColumn3 struct { | ||||||
|  | 		ID       uint | ||||||
|  | 		Name     string    `gorm:"size:256"` | ||||||
|  | 		Salary   float64   `gorm:"precision:3"` | ||||||
|  | 		Birthday time.Time `gorm:"precision:3"` | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil { | ||||||
|  | 		t.Fatalf("failed to auto migrate, got error: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	columnTypes, err = DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("failed to get column types, got error: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, columnType := range columnTypes { | ||||||
|  | 		switch columnType.Name() { | ||||||
|  | 		case "name": | ||||||
|  | 			if length, _ := columnType.Length(); length != 0 && length != 256 { | ||||||
|  | 				t.Fatalf("name's length should be 128, but got %v", length) | ||||||
|  | 			} | ||||||
|  | 		case "salary": | ||||||
|  | 			if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 { | ||||||
|  | 				t.Fatalf("salary's precision should be 2, but got %v", precision) | ||||||
|  | 			} | ||||||
|  | 		case "birthday": | ||||||
|  | 			if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 { | ||||||
|  | 				t.Fatalf("birthday's precision should be 2, but got %v", precision) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func TestMigrateWithComment(t *testing.T) { | func TestMigrateWithComment(t *testing.T) { | ||||||
| 	type UserWithComment struct { | 	type UserWithComment struct { | ||||||
| 		gorm.Model | 		gorm.Model | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu