Enhance migrator Columntype interface (#5088)
* Update Migrator ColumnType interface * Update MigrateColumn Test * Upgrade test drivers * Fix typo
This commit is contained in:
		
							parent
							
								
									39d84cba5f
								
							
						
					
					
						commit
						0af95f509a
					
				
							
								
								
									
										13
									
								
								migrator.go
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								migrator.go
									
									
									
									
									
								
							| @ -1,6 +1,8 @@ | |||||||
| package gorm | package gorm | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"reflect" | ||||||
|  | 
 | ||||||
| 	"gorm.io/gorm/clause" | 	"gorm.io/gorm/clause" | ||||||
| 	"gorm.io/gorm/schema" | 	"gorm.io/gorm/schema" | ||||||
| ) | ) | ||||||
| @ -33,14 +35,23 @@ type ViewOption struct { | |||||||
| 	Query       *DB | 	Query       *DB | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // ColumnType column type interface
 | ||||||
| type ColumnType interface { | type ColumnType interface { | ||||||
| 	Name() string | 	Name() string | ||||||
| 	DatabaseTypeName() string | 	DatabaseTypeName() string                 // varchar
 | ||||||
|  | 	ColumnType() (columnType string, ok bool) // varchar(64)
 | ||||||
|  | 	PrimaryKey() (isPrimaryKey bool, ok bool) | ||||||
|  | 	AutoIncrement() (isAutoIncrement bool, ok bool) | ||||||
| 	Length() (length int64, ok bool) | 	Length() (length int64, ok bool) | ||||||
| 	DecimalSize() (precision int64, scale int64, ok bool) | 	DecimalSize() (precision int64, scale int64, ok bool) | ||||||
| 	Nullable() (nullable bool, ok bool) | 	Nullable() (nullable bool, ok bool) | ||||||
|  | 	Unique() (unique bool, ok bool) | ||||||
|  | 	ScanType() reflect.Type | ||||||
|  | 	Comment() (value string, ok bool) | ||||||
|  | 	DefaultValue() (value string, ok bool) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // Migrator migrator interface
 | ||||||
| type Migrator interface { | type Migrator interface { | ||||||
| 	// AutoMigrate
 | 	// AutoMigrate
 | ||||||
| 	AutoMigrate(dst ...interface{}) error | 	AutoMigrate(dst ...interface{}) error | ||||||
|  | |||||||
							
								
								
									
										107
									
								
								migrator/column_type.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										107
									
								
								migrator/column_type.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,107 @@ | |||||||
|  | package migrator | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"reflect" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // ColumnType column type implements ColumnType interface
 | ||||||
|  | type ColumnType struct { | ||||||
|  | 	SQLColumnType      *sql.ColumnType | ||||||
|  | 	NameValue          sql.NullString | ||||||
|  | 	DataTypeValue      sql.NullString | ||||||
|  | 	ColumnTypeValue    sql.NullString | ||||||
|  | 	PrimayKeyValue     sql.NullBool | ||||||
|  | 	UniqueValue        sql.NullBool | ||||||
|  | 	AutoIncrementValue sql.NullBool | ||||||
|  | 	LengthValue        sql.NullInt64 | ||||||
|  | 	DecimalSizeValue   sql.NullInt64 | ||||||
|  | 	ScaleValue         sql.NullInt64 | ||||||
|  | 	NullableValue      sql.NullBool | ||||||
|  | 	ScanTypeValue      reflect.Type | ||||||
|  | 	CommentValue       sql.NullString | ||||||
|  | 	DefaultValueValue  sql.NullString | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Name returns the name or alias of the column.
 | ||||||
|  | func (ct ColumnType) Name() string { | ||||||
|  | 	if ct.NameValue.Valid { | ||||||
|  | 		return ct.NameValue.String | ||||||
|  | 	} | ||||||
|  | 	return ct.SQLColumnType.Name() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // DatabaseTypeName returns the database system name of the column type. If an empty
 | ||||||
|  | // string is returned, then the driver type name is not supported.
 | ||||||
|  | // Consult your driver documentation for a list of driver data types. Length specifiers
 | ||||||
|  | // are not included.
 | ||||||
|  | // Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL",
 | ||||||
|  | // "INT", and "BIGINT".
 | ||||||
|  | func (ct ColumnType) DatabaseTypeName() string { | ||||||
|  | 	if ct.DataTypeValue.Valid { | ||||||
|  | 		return ct.DataTypeValue.String | ||||||
|  | 	} | ||||||
|  | 	return ct.SQLColumnType.DatabaseTypeName() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ColumnType returns the database type of the column. lke `varchar(16)`
 | ||||||
|  | func (ct ColumnType) ColumnType() (columnType string, ok bool) { | ||||||
|  | 	return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // PrimaryKey returns the column is primary key or not.
 | ||||||
|  | func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) { | ||||||
|  | 	return ct.PrimayKeyValue.Bool, ct.PrimayKeyValue.Valid | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // AutoIncrement returns the column is auto increment or not.
 | ||||||
|  | func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) { | ||||||
|  | 	return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Length returns the column type length for variable length column types
 | ||||||
|  | func (ct ColumnType) Length() (length int64, ok bool) { | ||||||
|  | 	if ct.LengthValue.Valid { | ||||||
|  | 		return ct.LengthValue.Int64, true | ||||||
|  | 	} | ||||||
|  | 	return ct.SQLColumnType.Length() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // DecimalSize returns the scale and precision of a decimal type.
 | ||||||
|  | func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) { | ||||||
|  | 	if ct.DecimalSizeValue.Valid { | ||||||
|  | 		return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true | ||||||
|  | 	} | ||||||
|  | 	return ct.SQLColumnType.DecimalSize() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Nullable reports whether the column may be null.
 | ||||||
|  | func (ct ColumnType) Nullable() (nullable bool, ok bool) { | ||||||
|  | 	if ct.NullableValue.Valid { | ||||||
|  | 		return ct.NullableValue.Bool, true | ||||||
|  | 	} | ||||||
|  | 	return ct.SQLColumnType.Nullable() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Unique reports whether the column may be unique.
 | ||||||
|  | func (ct ColumnType) Unique() (unique bool, ok bool) { | ||||||
|  | 	return ct.UniqueValue.Bool, ct.UniqueValue.Valid | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ScanType returns a Go type suitable for scanning into using Rows.Scan.
 | ||||||
|  | func (ct ColumnType) ScanType() reflect.Type { | ||||||
|  | 	if ct.ScanTypeValue != nil { | ||||||
|  | 		return ct.ScanTypeValue | ||||||
|  | 	} | ||||||
|  | 	return ct.SQLColumnType.ScanType() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Comment returns the comment of current column.
 | ||||||
|  | func (ct ColumnType) Comment() (value string, ok bool) { | ||||||
|  | 	return ct.CommentValue.String, ct.CommentValue.Valid | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // DefaultValue returns the default value of current column.
 | ||||||
|  | func (ct ColumnType) DefaultValue() (value string, ok bool) { | ||||||
|  | 	return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid | ||||||
|  | } | ||||||
| @ -30,10 +30,12 @@ type Config struct { | |||||||
| 	gorm.Dialector | 	gorm.Dialector | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // GormDataTypeInterface gorm data type interface
 | ||||||
| type GormDataTypeInterface interface { | type GormDataTypeInterface interface { | ||||||
| 	GormDBDataType(*gorm.DB, *schema.Field) string | 	GormDBDataType(*gorm.DB, *schema.Field) string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // RunWithValue run migration with statement value
 | ||||||
| func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { | func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { | ||||||
| 	stmt := &gorm.Statement{DB: m.DB} | 	stmt := &gorm.Statement{DB: m.DB} | ||||||
| 	if m.DB.Statement != nil { | 	if m.DB.Statement != nil { | ||||||
| @ -50,6 +52,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error | |||||||
| 	return fc(stmt) | 	return fc(stmt) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // DataTypeOf return field's db data type
 | ||||||
| func (m Migrator) DataTypeOf(field *schema.Field) string { | func (m Migrator) DataTypeOf(field *schema.Field) string { | ||||||
| 	fieldValue := reflect.New(field.IndirectFieldType) | 	fieldValue := reflect.New(field.IndirectFieldType) | ||||||
| 	if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { | 	if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { | ||||||
| @ -61,6 +64,7 @@ func (m Migrator) DataTypeOf(field *schema.Field) string { | |||||||
| 	return m.Dialector.DataTypeOf(field) | 	return m.Dialector.DataTypeOf(field) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // FullDataTypeOf returns field's db full data type
 | ||||||
| func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { | func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { | ||||||
| 	expr.SQL = m.DataTypeOf(field) | 	expr.SQL = m.DataTypeOf(field) | ||||||
| 
 | 
 | ||||||
| @ -85,7 +89,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { | |||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // AutoMigrate
 | // AutoMigrate auto migrate values
 | ||||||
| func (m Migrator) AutoMigrate(values ...interface{}) error { | func (m Migrator) AutoMigrate(values ...interface{}) error { | ||||||
| 	for _, value := range m.ReorderModels(values, true) { | 	for _, value := range m.ReorderModels(values, true) { | ||||||
| 		tx := m.DB.Session(&gorm.Session{}) | 		tx := m.DB.Session(&gorm.Session{}) | ||||||
| @ -156,12 +160,14 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // GetTables returns tables
 | ||||||
| func (m Migrator) GetTables() (tableList []string, err error) { | func (m Migrator) GetTables() (tableList []string, err error) { | ||||||
| 	err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()). | 	err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()). | ||||||
| 		Scan(&tableList).Error | 		Scan(&tableList).Error | ||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // CreateTable create table in database for values
 | ||||||
| func (m Migrator) CreateTable(values ...interface{}) error { | func (m Migrator) CreateTable(values ...interface{}) error { | ||||||
| 	for _, value := range m.ReorderModels(values, false) { | 	for _, value := range m.ReorderModels(values, false) { | ||||||
| 		tx := m.DB.Session(&gorm.Session{}) | 		tx := m.DB.Session(&gorm.Session{}) | ||||||
| @ -252,6 +258,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // DropTable drop table for values
 | ||||||
| func (m Migrator) DropTable(values ...interface{}) error { | func (m Migrator) DropTable(values ...interface{}) error { | ||||||
| 	values = m.ReorderModels(values, false) | 	values = m.ReorderModels(values, false) | ||||||
| 	for i := len(values) - 1; i >= 0; i-- { | 	for i := len(values) - 1; i >= 0; i-- { | ||||||
| @ -265,6 +272,7 @@ func (m Migrator) DropTable(values ...interface{}) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // HasTable returns table exists or not for value, value could be a struct or string
 | ||||||
| func (m Migrator) HasTable(value interface{}) bool { | func (m Migrator) HasTable(value interface{}) bool { | ||||||
| 	var count int64 | 	var count int64 | ||||||
| 
 | 
 | ||||||
| @ -276,6 +284,7 @@ func (m Migrator) HasTable(value interface{}) bool { | |||||||
| 	return count > 0 | 	return count > 0 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // RenameTable rename table from oldName to newName
 | ||||||
| func (m Migrator) RenameTable(oldName, newName interface{}) error { | func (m Migrator) RenameTable(oldName, newName interface{}) error { | ||||||
| 	var oldTable, newTable interface{} | 	var oldTable, newTable interface{} | ||||||
| 	if v, ok := oldName.(string); ok { | 	if v, ok := oldName.(string); ok { | ||||||
| @ -303,12 +312,13 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error { | |||||||
| 	return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error | 	return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (m Migrator) AddColumn(value interface{}, field string) error { | // AddColumn create `name` column for value
 | ||||||
|  | func (m Migrator) AddColumn(value interface{}, name string) error { | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		// avoid using the same name field
 | 		// avoid using the same name field
 | ||||||
| 		f := stmt.Schema.LookUpField(field) | 		f := stmt.Schema.LookUpField(name) | ||||||
| 		if f == nil { | 		if f == nil { | ||||||
| 			return fmt.Errorf("failed to look up field with name: %s", field) | 			return fmt.Errorf("failed to look up field with name: %s", name) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if !f.IgnoreMigration { | 		if !f.IgnoreMigration { | ||||||
| @ -322,6 +332,7 @@ func (m Migrator) AddColumn(value interface{}, field string) error { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // DropColumn drop value's `name` column
 | ||||||
| func (m Migrator) DropColumn(value interface{}, name string) error { | func (m Migrator) DropColumn(value interface{}, name string) error { | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		if field := stmt.Schema.LookUpField(name); field != nil { | 		if field := stmt.Schema.LookUpField(name); field != nil { | ||||||
| @ -334,6 +345,7 @@ func (m Migrator) DropColumn(value interface{}, name string) error { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // AlterColumn alter value's `field` column' type based on schema definition
 | ||||||
| func (m Migrator) AlterColumn(value interface{}, field string) error { | func (m Migrator) AlterColumn(value interface{}, field string) error { | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		if field := stmt.Schema.LookUpField(field); field != nil { | 		if field := stmt.Schema.LookUpField(field); field != nil { | ||||||
| @ -348,6 +360,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // HasColumn check has column `field` for value or not
 | ||||||
| func (m Migrator) HasColumn(value interface{}, field string) bool { | func (m Migrator) HasColumn(value interface{}, field string) bool { | ||||||
| 	var count int64 | 	var count int64 | ||||||
| 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| @ -366,6 +379,7 @@ func (m Migrator) HasColumn(value interface{}, field string) bool { | |||||||
| 	return count > 0 | 	return count > 0 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // RenameColumn rename value's field name from oldName to newName
 | ||||||
| func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { | func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		if field := stmt.Schema.LookUpField(oldName); field != nil { | 		if field := stmt.Schema.LookUpField(oldName); field != nil { | ||||||
| @ -383,6 +397,7 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // MigrateColumn migrate column
 | ||||||
| func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { | func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { | ||||||
| 	// found, smart migrate
 | 	// found, smart migrate
 | ||||||
| 	fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) | 	fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) | ||||||
| @ -448,7 +463,7 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { | |||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		for _, c := range rawColumnTypes { | 		for _, c := range rawColumnTypes { | ||||||
| 			columnTypes = append(columnTypes, c) | 			columnTypes = append(columnTypes, ColumnType{SQLColumnType: c}) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		return | 		return | ||||||
| @ -457,10 +472,12 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { | |||||||
| 	return columnTypes, execErr | 	return columnTypes, execErr | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // CreateView create view
 | ||||||
| func (m Migrator) CreateView(name string, option gorm.ViewOption) error { | func (m Migrator) CreateView(name string, option gorm.ViewOption) error { | ||||||
| 	return gorm.ErrNotImplemented | 	return gorm.ErrNotImplemented | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // DropView drop view
 | ||||||
| func (m Migrator) DropView(name string) error { | func (m Migrator) DropView(name string) error { | ||||||
| 	return gorm.ErrNotImplemented | 	return gorm.ErrNotImplemented | ||||||
| } | } | ||||||
| @ -487,6 +504,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter | |||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // GuessConstraintAndTable guess statement's constraint and it's table based on name
 | ||||||
| func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) { | func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) { | ||||||
| 	if stmt.Schema == nil { | 	if stmt.Schema == nil { | ||||||
| 		return nil, nil, stmt.Table | 		return nil, nil, stmt.Table | ||||||
| @ -531,6 +549,7 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ | |||||||
| 	return nil, nil, stmt.Schema.Table | 	return nil, nil, stmt.Schema.Table | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // CreateConstraint create constraint
 | ||||||
| func (m Migrator) CreateConstraint(value interface{}, name string) error { | func (m Migrator) CreateConstraint(value interface{}, name string) error { | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		constraint, chk, table := m.GuessConstraintAndTable(stmt, name) | 		constraint, chk, table := m.GuessConstraintAndTable(stmt, name) | ||||||
| @ -554,6 +573,7 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // DropConstraint drop constraint
 | ||||||
| func (m Migrator) DropConstraint(value interface{}, name string) error { | func (m Migrator) DropConstraint(value interface{}, name string) error { | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		constraint, chk, table := m.GuessConstraintAndTable(stmt, name) | 		constraint, chk, table := m.GuessConstraintAndTable(stmt, name) | ||||||
| @ -566,6 +586,7 @@ func (m Migrator) DropConstraint(value interface{}, name string) error { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // HasConstraint check has constraint or not
 | ||||||
| func (m Migrator) HasConstraint(value interface{}, name string) bool { | func (m Migrator) HasConstraint(value interface{}, name string) bool { | ||||||
| 	var count int64 | 	var count int64 | ||||||
| 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| @ -586,6 +607,7 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool { | |||||||
| 	return count > 0 | 	return count > 0 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // BuildIndexOptions build index options
 | ||||||
| func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { | func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { | ||||||
| 	for _, opt := range opts { | 	for _, opt := range opts { | ||||||
| 		str := stmt.Quote(opt.DBName) | 		str := stmt.Quote(opt.DBName) | ||||||
| @ -607,10 +629,12 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem | |||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // BuildIndexOptionsInterface build index options interface
 | ||||||
| type BuildIndexOptionsInterface interface { | type BuildIndexOptionsInterface interface { | ||||||
| 	BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{} | 	BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // CreateIndex create index `name`
 | ||||||
| func (m Migrator) CreateIndex(value interface{}, name string) error { | func (m Migrator) CreateIndex(value interface{}, name string) error { | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		if idx := stmt.Schema.LookIndex(name); idx != nil { | 		if idx := stmt.Schema.LookIndex(name); idx != nil { | ||||||
| @ -642,6 +666,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // DropIndex drop index `name`
 | ||||||
| func (m Migrator) DropIndex(value interface{}, name string) error { | func (m Migrator) DropIndex(value interface{}, name string) error { | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		if idx := stmt.Schema.LookIndex(name); idx != nil { | 		if idx := stmt.Schema.LookIndex(name); idx != nil { | ||||||
| @ -652,6 +677,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error { | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // HasIndex check has index `name` or not
 | ||||||
| func (m Migrator) HasIndex(value interface{}, name string) bool { | func (m Migrator) HasIndex(value interface{}, name string) bool { | ||||||
| 	var count int64 | 	var count int64 | ||||||
| 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | 	m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| @ -669,6 +695,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool { | |||||||
| 	return count > 0 | 	return count > 0 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // RenameIndex rename index from oldName to newName
 | ||||||
| func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { | func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { | ||||||
| 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | 	return m.RunWithValue(value, func(stmt *gorm.Statement) error { | ||||||
| 		return m.DB.Exec( | 		return m.DB.Exec( | ||||||
| @ -678,6 +705,7 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error | |||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // CurrentDatabase returns current database name
 | ||||||
| func (m Migrator) CurrentDatabase() (name string) { | func (m Migrator) CurrentDatabase() (name string) { | ||||||
| 	m.DB.Raw("SELECT DATABASE()").Row().Scan(&name) | 	m.DB.Raw("SELECT DATABASE()").Row().Scan(&name) | ||||||
| 	return | 	return | ||||||
| @ -781,6 +809,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i | |||||||
| 	return | 	return | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // CurrentTable returns current statement's table expression
 | ||||||
| func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { | func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} { | ||||||
| 	if stmt.TableExpr != nil { | 	if stmt.TableExpr != nil { | ||||||
| 		return *stmt.TableExpr | 		return *stmt.TableExpr | ||||||
|  | |||||||
| @ -3,17 +3,16 @@ module gorm.io/gorm/tests | |||||||
| go 1.14 | go 1.14 | ||||||
| 
 | 
 | ||||||
| require ( | require ( | ||||||
| 	github.com/denisenkom/go-mssqldb v0.12.0 // indirect |  | ||||||
| 	github.com/google/uuid v1.3.0 | 	github.com/google/uuid v1.3.0 | ||||||
| 	github.com/jackc/pgx/v4 v4.15.0 // indirect | 	github.com/jackc/pgx/v4 v4.15.0 // indirect | ||||||
| 	github.com/jinzhu/now v1.1.4 | 	github.com/jinzhu/now v1.1.4 | ||||||
| 	github.com/lib/pq v1.10.4 | 	github.com/lib/pq v1.10.4 | ||||||
| 	github.com/mattn/go-sqlite3 v1.14.11 // indirect | 	github.com/mattn/go-sqlite3 v1.14.11 // indirect | ||||||
| 	golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect | 	golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect | ||||||
| 	gorm.io/driver/mysql v1.2.3 | 	gorm.io/driver/mysql v1.3.0 | ||||||
| 	gorm.io/driver/postgres v1.2.3 | 	gorm.io/driver/postgres v1.3.0 | ||||||
| 	gorm.io/driver/sqlite v1.2.6 | 	gorm.io/driver/sqlite v1.3.0 | ||||||
| 	gorm.io/driver/sqlserver v1.2.1 | 	gorm.io/driver/sqlserver v1.3.0 | ||||||
| 	gorm.io/gorm v1.22.5 | 	gorm.io/gorm v1.22.5 | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -92,7 +92,7 @@ func TestAutoMigrateSelfReferential(t *testing.T) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestSmartMigrateColumn(t *testing.T) { | func TestSmartMigrateColumn(t *testing.T) { | ||||||
| 	fullSupported := map[string]bool{"mysql": true}[DB.Dialector.Name()] | 	fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()] | ||||||
| 
 | 
 | ||||||
| 	type UserMigrateColumn struct { | 	type UserMigrateColumn struct { | ||||||
| 		ID       uint | 		ID       uint | ||||||
| @ -313,9 +313,15 @@ func TestMigrateIndexes(t *testing.T) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestMigrateColumns(t *testing.T) { | func TestMigrateColumns(t *testing.T) { | ||||||
|  | 	fullSupported := map[string]bool{"sqlite": true, "mysql": true, "postgres": true, "sqlserver": true}[DB.Dialector.Name()] | ||||||
|  | 	sqlite := DB.Dialector.Name() == "sqlite" | ||||||
|  | 	sqlserver := DB.Dialector.Name() == "sqlserver" | ||||||
|  | 
 | ||||||
| 	type ColumnStruct struct { | 	type ColumnStruct struct { | ||||||
| 		gorm.Model | 		gorm.Model | ||||||
| 		Name string | 		Name string | ||||||
|  | 		Age  int    `gorm:"default:18;comment:my age"` | ||||||
|  | 		Code string `gorm:"unique"` | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	DB.Migrator().DropTable(&ColumnStruct{}) | 	DB.Migrator().DropTable(&ColumnStruct{}) | ||||||
| @ -340,10 +346,29 @@ func TestMigrateColumns(t *testing.T) { | |||||||
| 		stmt.Parse(&ColumnStruct2{}) | 		stmt.Parse(&ColumnStruct2{}) | ||||||
| 
 | 
 | ||||||
| 		for _, columnType := range columnTypes { | 		for _, columnType := range columnTypes { | ||||||
| 			if columnType.Name() == "name" { | 			switch columnType.Name() { | ||||||
|  | 			case "id": | ||||||
|  | 				if v, ok := columnType.PrimaryKey(); (fullSupported || ok) && !v { | ||||||
|  | 					t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) | ||||||
|  | 				} | ||||||
|  | 			case "name": | ||||||
| 				dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) | 				dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) | ||||||
| 				if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { | 				if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { | ||||||
| 					t.Errorf("column type should be correct, name: %v, length: %v, expects: %v", columnType.Name(), columnType.DatabaseTypeName(), dataType) | 					t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) | ||||||
|  | 				} | ||||||
|  | 				if length, ok := columnType.Length(); ((fullSupported && !sqlite) || ok) && length != 100 { | ||||||
|  | 					t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType) | ||||||
|  | 				} | ||||||
|  | 			case "age": | ||||||
|  | 				if v, ok := columnType.DefaultValue(); (fullSupported || ok) && v != "18" { | ||||||
|  | 					t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) | ||||||
|  | 				} | ||||||
|  | 				if v, ok := columnType.Comment(); ((fullSupported && !sqlite && !sqlserver) || ok) && v != "my age" { | ||||||
|  | 					t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) | ||||||
|  | 				} | ||||||
|  | 			case "code": | ||||||
|  | 				if v, ok := columnType.Unique(); (fullSupported || ok) && !v { | ||||||
|  | 					t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
| 		} | 		} | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu