Enhance migrator Columntype interface (#5088)
* Update Migrator ColumnType interface * Update MigrateColumn Test * Upgrade test drivers * Fix typo
This commit is contained in:
		
							parent
							
								
									5299a0f9da
								
							
						
					
					
						commit
						b17c550011
					
				
							
								
								
									
										13
									
								
								migrator.go
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								migrator.go
									
									
									
									
									
								
							@ -1,6 +1,8 @@
 | 
				
			|||||||
package gorm
 | 
					package gorm
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"gorm.io/gorm/clause"
 | 
						"gorm.io/gorm/clause"
 | 
				
			||||||
	"gorm.io/gorm/schema"
 | 
						"gorm.io/gorm/schema"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@ -33,14 +35,23 @@ type ViewOption struct {
 | 
				
			|||||||
	Query       *DB
 | 
						Query       *DB
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ColumnType column type interface
 | 
				
			||||||
type ColumnType interface {
 | 
					type ColumnType interface {
 | 
				
			||||||
	Name() string
 | 
						Name() string
 | 
				
			||||||
	DatabaseTypeName() string
 | 
						DatabaseTypeName() string                 // varchar
 | 
				
			||||||
 | 
						ColumnType() (columnType string, ok bool) // varchar(64)
 | 
				
			||||||
 | 
						PrimaryKey() (isPrimaryKey bool, ok bool)
 | 
				
			||||||
 | 
						AutoIncrement() (isAutoIncrement bool, ok bool)
 | 
				
			||||||
	Length() (length int64, ok bool)
 | 
						Length() (length int64, ok bool)
 | 
				
			||||||
	DecimalSize() (precision int64, scale int64, ok bool)
 | 
						DecimalSize() (precision int64, scale int64, ok bool)
 | 
				
			||||||
	Nullable() (nullable bool, ok bool)
 | 
						Nullable() (nullable bool, ok bool)
 | 
				
			||||||
 | 
						Unique() (unique bool, ok bool)
 | 
				
			||||||
 | 
						ScanType() reflect.Type
 | 
				
			||||||
 | 
						Comment() (value string, ok bool)
 | 
				
			||||||
 | 
						DefaultValue() (value string, ok bool)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Migrator migrator interface
 | 
				
			||||||
type Migrator interface {
 | 
					type Migrator interface {
 | 
				
			||||||
	// AutoMigrate
 | 
						// AutoMigrate
 | 
				
			||||||
	AutoMigrate(dst ...interface{}) error
 | 
						AutoMigrate(dst ...interface{}) error
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										107
									
								
								migrator/column_type.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										107
									
								
								migrator/column_type.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,107 @@
 | 
				
			|||||||
 | 
					package migrator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"database/sql"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ColumnType column type implements ColumnType interface
 | 
				
			||||||
 | 
					type ColumnType struct {
 | 
				
			||||||
 | 
						SQLColumnType      *sql.ColumnType
 | 
				
			||||||
 | 
						NameValue          sql.NullString
 | 
				
			||||||
 | 
						DataTypeValue      sql.NullString
 | 
				
			||||||
 | 
						ColumnTypeValue    sql.NullString
 | 
				
			||||||
 | 
						PrimayKeyValue     sql.NullBool
 | 
				
			||||||
 | 
						UniqueValue        sql.NullBool
 | 
				
			||||||
 | 
						AutoIncrementValue sql.NullBool
 | 
				
			||||||
 | 
						LengthValue        sql.NullInt64
 | 
				
			||||||
 | 
						DecimalSizeValue   sql.NullInt64
 | 
				
			||||||
 | 
						ScaleValue         sql.NullInt64
 | 
				
			||||||
 | 
						NullableValue      sql.NullBool
 | 
				
			||||||
 | 
						ScanTypeValue      reflect.Type
 | 
				
			||||||
 | 
						CommentValue       sql.NullString
 | 
				
			||||||
 | 
						DefaultValueValue  sql.NullString
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Name returns the name or alias of the column.
 | 
				
			||||||
 | 
					func (ct ColumnType) Name() string {
 | 
				
			||||||
 | 
						if ct.NameValue.Valid {
 | 
				
			||||||
 | 
							return ct.NameValue.String
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return ct.SQLColumnType.Name()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// DatabaseTypeName returns the database system name of the column type. If an empty
 | 
				
			||||||
 | 
					// string is returned, then the driver type name is not supported.
 | 
				
			||||||
 | 
					// Consult your driver documentation for a list of driver data types. Length specifiers
 | 
				
			||||||
 | 
					// are not included.
 | 
				
			||||||
 | 
					// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL",
 | 
				
			||||||
 | 
					// "INT", and "BIGINT".
 | 
				
			||||||
 | 
					func (ct ColumnType) DatabaseTypeName() string {
 | 
				
			||||||
 | 
						if ct.DataTypeValue.Valid {
 | 
				
			||||||
 | 
							return ct.DataTypeValue.String
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return ct.SQLColumnType.DatabaseTypeName()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ColumnType returns the database type of the column. lke `varchar(16)`
 | 
				
			||||||
 | 
					func (ct ColumnType) ColumnType() (columnType string, ok bool) {
 | 
				
			||||||
 | 
						return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// PrimaryKey returns the column is primary key or not.
 | 
				
			||||||
 | 
					func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) {
 | 
				
			||||||
 | 
						return ct.PrimayKeyValue.Bool, ct.PrimayKeyValue.Valid
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// AutoIncrement returns the column is auto increment or not.
 | 
				
			||||||
 | 
					func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) {
 | 
				
			||||||
 | 
						return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Length returns the column type length for variable length column types
 | 
				
			||||||
 | 
					func (ct ColumnType) Length() (length int64, ok bool) {
 | 
				
			||||||
 | 
						if ct.LengthValue.Valid {
 | 
				
			||||||
 | 
							return ct.LengthValue.Int64, true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return ct.SQLColumnType.Length()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// DecimalSize returns the scale and precision of a decimal type.
 | 
				
			||||||
 | 
					func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) {
 | 
				
			||||||
 | 
						if ct.DecimalSizeValue.Valid {
 | 
				
			||||||
 | 
							return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return ct.SQLColumnType.DecimalSize()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Nullable reports whether the column may be null.
 | 
				
			||||||
 | 
					func (ct ColumnType) Nullable() (nullable bool, ok bool) {
 | 
				
			||||||
 | 
						if ct.NullableValue.Valid {
 | 
				
			||||||
 | 
							return ct.NullableValue.Bool, true
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return ct.SQLColumnType.Nullable()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Unique reports whether the column may be unique.
 | 
				
			||||||
 | 
					func (ct ColumnType) Unique() (unique bool, ok bool) {
 | 
				
			||||||
 | 
						return ct.UniqueValue.Bool, ct.UniqueValue.Valid
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ScanType returns a Go type suitable for scanning into using Rows.Scan.
 | 
				
			||||||
 | 
					func (ct ColumnType) ScanType() reflect.Type {
 | 
				
			||||||
 | 
						if ct.ScanTypeValue != nil {
 | 
				
			||||||
 | 
							return ct.ScanTypeValue
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return ct.SQLColumnType.ScanType()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Comment returns the comment of current column.
 | 
				
			||||||
 | 
					func (ct ColumnType) Comment() (value string, ok bool) {
 | 
				
			||||||
 | 
						return ct.CommentValue.String, ct.CommentValue.Valid
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// DefaultValue returns the default value of current column.
 | 
				
			||||||
 | 
					func (ct ColumnType) DefaultValue() (value string, ok bool) {
 | 
				
			||||||
 | 
						return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@ -30,10 +30,12 @@ type Config struct {
 | 
				
			|||||||
	gorm.Dialector
 | 
						gorm.Dialector
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// GormDataTypeInterface gorm data type interface
 | 
				
			||||||
type GormDataTypeInterface interface {
 | 
					type GormDataTypeInterface interface {
 | 
				
			||||||
	GormDBDataType(*gorm.DB, *schema.Field) string
 | 
						GormDBDataType(*gorm.DB, *schema.Field) string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// RunWithValue run migration with statement value
 | 
				
			||||||
func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
 | 
					func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error {
 | 
				
			||||||
	stmt := &gorm.Statement{DB: m.DB}
 | 
						stmt := &gorm.Statement{DB: m.DB}
 | 
				
			||||||
	if m.DB.Statement != nil {
 | 
						if m.DB.Statement != nil {
 | 
				
			||||||
@ -50,6 +52,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error
 | 
				
			|||||||
	return fc(stmt)
 | 
						return fc(stmt)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// DataTypeOf return field's db data type
 | 
				
			||||||
func (m Migrator) DataTypeOf(field *schema.Field) string {
 | 
					func (m Migrator) DataTypeOf(field *schema.Field) string {
 | 
				
			||||||
	fieldValue := reflect.New(field.IndirectFieldType)
 | 
						fieldValue := reflect.New(field.IndirectFieldType)
 | 
				
			||||||
	if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
 | 
						if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
 | 
				
			||||||
@ -61,6 +64,7 @@ func (m Migrator) DataTypeOf(field *schema.Field) string {
 | 
				
			|||||||
	return m.Dialector.DataTypeOf(field)
 | 
						return m.Dialector.DataTypeOf(field)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// FullDataTypeOf returns field's db full data type
 | 
				
			||||||
func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
 | 
					func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
 | 
				
			||||||
	expr.SQL = m.DataTypeOf(field)
 | 
						expr.SQL = m.DataTypeOf(field)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -85,7 +89,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) {
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// AutoMigrate
 | 
					// AutoMigrate auto migrate values
 | 
				
			||||||
func (m Migrator) AutoMigrate(values ...interface{}) error {
 | 
					func (m Migrator) AutoMigrate(values ...interface{}) error {
 | 
				
			||||||
	for _, value := range m.ReorderModels(values, true) {
 | 
						for _, value := range m.ReorderModels(values, true) {
 | 
				
			||||||
		tx := m.DB.Session(&gorm.Session{})
 | 
							tx := m.DB.Session(&gorm.Session{})
 | 
				
			||||||
@ -156,12 +160,14 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
 | 
				
			|||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// GetTables returns tables
 | 
				
			||||||
func (m Migrator) GetTables() (tableList []string, err error) {
 | 
					func (m Migrator) GetTables() (tableList []string, err error) {
 | 
				
			||||||
	err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).
 | 
						err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()).
 | 
				
			||||||
		Scan(&tableList).Error
 | 
							Scan(&tableList).Error
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CreateTable create table in database for values
 | 
				
			||||||
func (m Migrator) CreateTable(values ...interface{}) error {
 | 
					func (m Migrator) CreateTable(values ...interface{}) error {
 | 
				
			||||||
	for _, value := range m.ReorderModels(values, false) {
 | 
						for _, value := range m.ReorderModels(values, false) {
 | 
				
			||||||
		tx := m.DB.Session(&gorm.Session{})
 | 
							tx := m.DB.Session(&gorm.Session{})
 | 
				
			||||||
@ -252,6 +258,7 @@ func (m Migrator) CreateTable(values ...interface{}) error {
 | 
				
			|||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// DropTable drop table for values
 | 
				
			||||||
func (m Migrator) DropTable(values ...interface{}) error {
 | 
					func (m Migrator) DropTable(values ...interface{}) error {
 | 
				
			||||||
	values = m.ReorderModels(values, false)
 | 
						values = m.ReorderModels(values, false)
 | 
				
			||||||
	for i := len(values) - 1; i >= 0; i-- {
 | 
						for i := len(values) - 1; i >= 0; i-- {
 | 
				
			||||||
@ -265,6 +272,7 @@ func (m Migrator) DropTable(values ...interface{}) error {
 | 
				
			|||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// HasTable returns table exists or not for value, value could be a struct or string
 | 
				
			||||||
func (m Migrator) HasTable(value interface{}) bool {
 | 
					func (m Migrator) HasTable(value interface{}) bool {
 | 
				
			||||||
	var count int64
 | 
						var count int64
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -276,6 +284,7 @@ func (m Migrator) HasTable(value interface{}) bool {
 | 
				
			|||||||
	return count > 0
 | 
						return count > 0
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// RenameTable rename table from oldName to newName
 | 
				
			||||||
func (m Migrator) RenameTable(oldName, newName interface{}) error {
 | 
					func (m Migrator) RenameTable(oldName, newName interface{}) error {
 | 
				
			||||||
	var oldTable, newTable interface{}
 | 
						var oldTable, newTable interface{}
 | 
				
			||||||
	if v, ok := oldName.(string); ok {
 | 
						if v, ok := oldName.(string); ok {
 | 
				
			||||||
@ -303,12 +312,13 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
 | 
				
			|||||||
	return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error
 | 
						return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (m Migrator) AddColumn(value interface{}, field string) error {
 | 
					// AddColumn create `name` column for value
 | 
				
			||||||
 | 
					func (m Migrator) AddColumn(value interface{}, name string) error {
 | 
				
			||||||
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
		// avoid using the same name field
 | 
							// avoid using the same name field
 | 
				
			||||||
		f := stmt.Schema.LookUpField(field)
 | 
							f := stmt.Schema.LookUpField(name)
 | 
				
			||||||
		if f == nil {
 | 
							if f == nil {
 | 
				
			||||||
			return fmt.Errorf("failed to look up field with name: %s", field)
 | 
								return fmt.Errorf("failed to look up field with name: %s", name)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if !f.IgnoreMigration {
 | 
							if !f.IgnoreMigration {
 | 
				
			||||||
@ -322,6 +332,7 @@ func (m Migrator) AddColumn(value interface{}, field string) error {
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// DropColumn drop value's `name` column
 | 
				
			||||||
func (m Migrator) DropColumn(value interface{}, name string) error {
 | 
					func (m Migrator) DropColumn(value interface{}, name string) error {
 | 
				
			||||||
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
		if field := stmt.Schema.LookUpField(name); field != nil {
 | 
							if field := stmt.Schema.LookUpField(name); field != nil {
 | 
				
			||||||
@ -334,6 +345,7 @@ func (m Migrator) DropColumn(value interface{}, name string) error {
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// AlterColumn alter value's `field` column' type based on schema definition
 | 
				
			||||||
func (m Migrator) AlterColumn(value interface{}, field string) error {
 | 
					func (m Migrator) AlterColumn(value interface{}, field string) error {
 | 
				
			||||||
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
		if field := stmt.Schema.LookUpField(field); field != nil {
 | 
							if field := stmt.Schema.LookUpField(field); field != nil {
 | 
				
			||||||
@ -348,6 +360,7 @@ func (m Migrator) AlterColumn(value interface{}, field string) error {
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// HasColumn check has column `field` for value or not
 | 
				
			||||||
func (m Migrator) HasColumn(value interface{}, field string) bool {
 | 
					func (m Migrator) HasColumn(value interface{}, field string) bool {
 | 
				
			||||||
	var count int64
 | 
						var count int64
 | 
				
			||||||
	m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
@ -366,6 +379,7 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
 | 
				
			|||||||
	return count > 0
 | 
						return count > 0
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// RenameColumn rename value's field name from oldName to newName
 | 
				
			||||||
func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
 | 
					func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error {
 | 
				
			||||||
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
		if field := stmt.Schema.LookUpField(oldName); field != nil {
 | 
							if field := stmt.Schema.LookUpField(oldName); field != nil {
 | 
				
			||||||
@ -383,6 +397,7 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// MigrateColumn migrate column
 | 
				
			||||||
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
 | 
					func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
 | 
				
			||||||
	// found, smart migrate
 | 
						// found, smart migrate
 | 
				
			||||||
	fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)
 | 
						fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)
 | 
				
			||||||
@ -448,7 +463,7 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		for _, c := range rawColumnTypes {
 | 
							for _, c := range rawColumnTypes {
 | 
				
			||||||
			columnTypes = append(columnTypes, c)
 | 
								columnTypes = append(columnTypes, ColumnType{SQLColumnType: c})
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
@ -457,10 +472,12 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
 | 
				
			|||||||
	return columnTypes, execErr
 | 
						return columnTypes, execErr
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CreateView create view
 | 
				
			||||||
func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
 | 
					func (m Migrator) CreateView(name string, option gorm.ViewOption) error {
 | 
				
			||||||
	return gorm.ErrNotImplemented
 | 
						return gorm.ErrNotImplemented
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// DropView drop view
 | 
				
			||||||
func (m Migrator) DropView(name string) error {
 | 
					func (m Migrator) DropView(name string) error {
 | 
				
			||||||
	return gorm.ErrNotImplemented
 | 
						return gorm.ErrNotImplemented
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@ -487,6 +504,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// GuessConstraintAndTable guess statement's constraint and it's table based on name
 | 
				
			||||||
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) {
 | 
					func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_ *schema.Constraint, _ *schema.Check, table string) {
 | 
				
			||||||
	if stmt.Schema == nil {
 | 
						if stmt.Schema == nil {
 | 
				
			||||||
		return nil, nil, stmt.Table
 | 
							return nil, nil, stmt.Table
 | 
				
			||||||
@ -531,6 +549,7 @@ func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (_
 | 
				
			|||||||
	return nil, nil, stmt.Schema.Table
 | 
						return nil, nil, stmt.Schema.Table
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CreateConstraint create constraint
 | 
				
			||||||
func (m Migrator) CreateConstraint(value interface{}, name string) error {
 | 
					func (m Migrator) CreateConstraint(value interface{}, name string) error {
 | 
				
			||||||
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
		constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
 | 
							constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
 | 
				
			||||||
@ -554,6 +573,7 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// DropConstraint drop constraint
 | 
				
			||||||
func (m Migrator) DropConstraint(value interface{}, name string) error {
 | 
					func (m Migrator) DropConstraint(value interface{}, name string) error {
 | 
				
			||||||
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
		constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
 | 
							constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
 | 
				
			||||||
@ -566,6 +586,7 @@ func (m Migrator) DropConstraint(value interface{}, name string) error {
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// HasConstraint check has constraint or not
 | 
				
			||||||
func (m Migrator) HasConstraint(value interface{}, name string) bool {
 | 
					func (m Migrator) HasConstraint(value interface{}, name string) bool {
 | 
				
			||||||
	var count int64
 | 
						var count int64
 | 
				
			||||||
	m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
@ -586,6 +607,7 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
 | 
				
			|||||||
	return count > 0
 | 
						return count > 0
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// BuildIndexOptions build index options
 | 
				
			||||||
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
 | 
					func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
 | 
				
			||||||
	for _, opt := range opts {
 | 
						for _, opt := range opts {
 | 
				
			||||||
		str := stmt.Quote(opt.DBName)
 | 
							str := stmt.Quote(opt.DBName)
 | 
				
			||||||
@ -607,10 +629,12 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// BuildIndexOptionsInterface build index options interface
 | 
				
			||||||
type BuildIndexOptionsInterface interface {
 | 
					type BuildIndexOptionsInterface interface {
 | 
				
			||||||
	BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{}
 | 
						BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CreateIndex create index `name`
 | 
				
			||||||
func (m Migrator) CreateIndex(value interface{}, name string) error {
 | 
					func (m Migrator) CreateIndex(value interface{}, name string) error {
 | 
				
			||||||
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
		if idx := stmt.Schema.LookIndex(name); idx != nil {
 | 
							if idx := stmt.Schema.LookIndex(name); idx != nil {
 | 
				
			||||||
@ -642,6 +666,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// DropIndex drop index `name`
 | 
				
			||||||
func (m Migrator) DropIndex(value interface{}, name string) error {
 | 
					func (m Migrator) DropIndex(value interface{}, name string) error {
 | 
				
			||||||
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
		if idx := stmt.Schema.LookIndex(name); idx != nil {
 | 
							if idx := stmt.Schema.LookIndex(name); idx != nil {
 | 
				
			||||||
@ -652,6 +677,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error {
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// HasIndex check has index `name` or not
 | 
				
			||||||
func (m Migrator) HasIndex(value interface{}, name string) bool {
 | 
					func (m Migrator) HasIndex(value interface{}, name string) bool {
 | 
				
			||||||
	var count int64
 | 
						var count int64
 | 
				
			||||||
	m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
@ -669,6 +695,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
 | 
				
			|||||||
	return count > 0
 | 
						return count > 0
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// RenameIndex rename index from oldName to newName
 | 
				
			||||||
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
 | 
					func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
 | 
				
			||||||
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
						return m.RunWithValue(value, func(stmt *gorm.Statement) error {
 | 
				
			||||||
		return m.DB.Exec(
 | 
							return m.DB.Exec(
 | 
				
			||||||
@ -678,6 +705,7 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error
 | 
				
			|||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CurrentDatabase returns current database name
 | 
				
			||||||
func (m Migrator) CurrentDatabase() (name string) {
 | 
					func (m Migrator) CurrentDatabase() (name string) {
 | 
				
			||||||
	m.DB.Raw("SELECT DATABASE()").Row().Scan(&name)
 | 
						m.DB.Raw("SELECT DATABASE()").Row().Scan(&name)
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
@ -781,6 +809,7 @@ func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []i
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CurrentTable returns current statement's table expression
 | 
				
			||||||
func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
 | 
					func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
 | 
				
			||||||
	if stmt.TableExpr != nil {
 | 
						if stmt.TableExpr != nil {
 | 
				
			||||||
		return *stmt.TableExpr
 | 
							return *stmt.TableExpr
 | 
				
			||||||
 | 
				
			|||||||
@ -3,17 +3,16 @@ module gorm.io/gorm/tests
 | 
				
			|||||||
go 1.14
 | 
					go 1.14
 | 
				
			||||||
 | 
					
 | 
				
			||||||
require (
 | 
					require (
 | 
				
			||||||
	github.com/denisenkom/go-mssqldb v0.12.0 // indirect
 | 
					 | 
				
			||||||
	github.com/google/uuid v1.3.0
 | 
						github.com/google/uuid v1.3.0
 | 
				
			||||||
	github.com/jackc/pgx/v4 v4.15.0 // indirect
 | 
						github.com/jackc/pgx/v4 v4.15.0 // indirect
 | 
				
			||||||
	github.com/jinzhu/now v1.1.4
 | 
						github.com/jinzhu/now v1.1.4
 | 
				
			||||||
	github.com/lib/pq v1.10.4
 | 
						github.com/lib/pq v1.10.4
 | 
				
			||||||
	github.com/mattn/go-sqlite3 v1.14.11 // indirect
 | 
						github.com/mattn/go-sqlite3 v1.14.11 // indirect
 | 
				
			||||||
	golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect
 | 
						golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect
 | 
				
			||||||
	gorm.io/driver/mysql v1.2.3
 | 
						gorm.io/driver/mysql v1.3.0
 | 
				
			||||||
	gorm.io/driver/postgres v1.2.3
 | 
						gorm.io/driver/postgres v1.3.0
 | 
				
			||||||
	gorm.io/driver/sqlite v1.2.6
 | 
						gorm.io/driver/sqlite v1.3.0
 | 
				
			||||||
	gorm.io/driver/sqlserver v1.2.1
 | 
						gorm.io/driver/sqlserver v1.3.0
 | 
				
			||||||
	gorm.io/gorm v1.22.5
 | 
						gorm.io/gorm v1.22.5
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -92,7 +92,7 @@ func TestAutoMigrateSelfReferential(t *testing.T) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestSmartMigrateColumn(t *testing.T) {
 | 
					func TestSmartMigrateColumn(t *testing.T) {
 | 
				
			||||||
	fullSupported := map[string]bool{"mysql": true}[DB.Dialector.Name()]
 | 
						fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	type UserMigrateColumn struct {
 | 
						type UserMigrateColumn struct {
 | 
				
			||||||
		ID       uint
 | 
							ID       uint
 | 
				
			||||||
@ -313,9 +313,15 @@ func TestMigrateIndexes(t *testing.T) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestMigrateColumns(t *testing.T) {
 | 
					func TestMigrateColumns(t *testing.T) {
 | 
				
			||||||
 | 
						fullSupported := map[string]bool{"sqlite": true, "mysql": true, "postgres": true, "sqlserver": true}[DB.Dialector.Name()]
 | 
				
			||||||
 | 
						sqlite := DB.Dialector.Name() == "sqlite"
 | 
				
			||||||
 | 
						sqlserver := DB.Dialector.Name() == "sqlserver"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	type ColumnStruct struct {
 | 
						type ColumnStruct struct {
 | 
				
			||||||
		gorm.Model
 | 
							gorm.Model
 | 
				
			||||||
		Name string
 | 
							Name string
 | 
				
			||||||
 | 
							Age  int    `gorm:"default:18;comment:my age"`
 | 
				
			||||||
 | 
							Code string `gorm:"unique"`
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	DB.Migrator().DropTable(&ColumnStruct{})
 | 
						DB.Migrator().DropTable(&ColumnStruct{})
 | 
				
			||||||
@ -340,10 +346,29 @@ func TestMigrateColumns(t *testing.T) {
 | 
				
			|||||||
		stmt.Parse(&ColumnStruct2{})
 | 
							stmt.Parse(&ColumnStruct2{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		for _, columnType := range columnTypes {
 | 
							for _, columnType := range columnTypes {
 | 
				
			||||||
			if columnType.Name() == "name" {
 | 
								switch columnType.Name() {
 | 
				
			||||||
 | 
								case "id":
 | 
				
			||||||
 | 
									if v, ok := columnType.PrimaryKey(); (fullSupported || ok) && !v {
 | 
				
			||||||
 | 
										t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								case "name":
 | 
				
			||||||
				dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
 | 
									dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
 | 
				
			||||||
				if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
 | 
									if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
 | 
				
			||||||
					t.Errorf("column type should be correct, name: %v, length: %v, expects: %v", columnType.Name(), columnType.DatabaseTypeName(), dataType)
 | 
										t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									if length, ok := columnType.Length(); ((fullSupported && !sqlite) || ok) && length != 100 {
 | 
				
			||||||
 | 
										t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								case "age":
 | 
				
			||||||
 | 
									if v, ok := columnType.DefaultValue(); (fullSupported || ok) && v != "18" {
 | 
				
			||||||
 | 
										t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									if v, ok := columnType.Comment(); ((fullSupported && !sqlite && !sqlserver) || ok) && v != "my age" {
 | 
				
			||||||
 | 
										t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								case "code":
 | 
				
			||||||
 | 
									if v, ok := columnType.Unique(); (fullSupported || ok) && !v {
 | 
				
			||||||
 | 
										t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user