diff --git a/migrator.go b/migrator.go index 162fe680..ac06a144 100644 --- a/migrator.go +++ b/migrator.go @@ -1,8 +1,6 @@ package gorm import ( - "database/sql" - "gorm.io/gorm/clause" "gorm.io/gorm/schema" ) @@ -24,6 +22,14 @@ type ViewOption struct { Query *DB } +type ColumnType interface { + Name() string + DatabaseTypeName() string + Length() (length int64, ok bool) + DecimalSize() (precision int64, scale int64, ok bool) + Nullable() (nullable bool, ok bool) +} + type Migrator interface { // AutoMigrate AutoMigrate(dst ...interface{}) error @@ -42,10 +48,10 @@ type Migrator interface { AddColumn(dst interface{}, field string) error DropColumn(dst interface{}, field string) error AlterColumn(dst interface{}, field string) error - MigrateColumn(dst interface{}, field *schema.Field, columnType *sql.ColumnType) error + MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error HasColumn(dst interface{}, field string) bool RenameColumn(dst interface{}, oldName, field string) error - ColumnTypes(dst interface{}) ([]*sql.ColumnType, error) + ColumnTypes(dst interface{}) ([]ColumnType, error) // Views CreateView(name string, option ViewOption) error diff --git a/migrator/migrator.go b/migrator/migrator.go index f390ff9f..ca8e63ca 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -2,7 +2,6 @@ package migrator import ( "context" - "database/sql" "fmt" "reflect" "regexp" @@ -92,7 +91,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { columnTypes, _ := m.DB.Migrator().ColumnTypes(value) for _, field := range stmt.Schema.FieldsByDBName { - var foundColumn *sql.ColumnType + var foundColumn gorm.ColumnType for _, columnType := range columnTypes { if columnType.Name() == field.DBName { @@ -352,7 +351,7 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error }) } -func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType *sql.ColumnType) error { +func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { // found, smart migrate fullDataType := strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL) realDataType := strings.ToLower(columnType.DatabaseTypeName()) @@ -395,12 +394,18 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy return nil } -func (m Migrator) ColumnTypes(value interface{}) (columnTypes []*sql.ColumnType, err error) { +func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) { + columnTypes = make([]gorm.ColumnType, 0) err = m.RunWithValue(value, func(stmt *gorm.Statement) error { rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() if err == nil { defer rows.Close() - columnTypes, err = rows.ColumnTypes() + rawColumnTypes, err := rows.ColumnTypes() + if err == nil { + for _, c := range rawColumnTypes { + columnTypes = append(columnTypes, c) + } + } } return err })