diff --git a/finisher_api.go b/finisher_api.go index 2951fdef..857f9419 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -326,6 +326,15 @@ func (db *DB) Count(count *int64) (tx *DB) { defer tx.Statement.AddClause(clause.Select{}) } + if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok { + if _, ok := db.Statement.Clauses["GROUP BY"]; !ok { + delete(db.Statement.Clauses, "ORDER BY") + defer func() { + db.Statement.Clauses["ORDER BY"] = orderByClause + }() + } + } + tx.Statement.Dest = count tx.callbacks.Query().Execute(tx) if tx.RowsAffected != 1 { @@ -356,9 +365,13 @@ func (db *DB) Rows() (*sql.Rows, error) { // Scan scan value to a struct func (db *DB) Scan(dest interface{}) (tx *DB) { - currentLogger, newLogger := db.Logger, logger.Recorder.New() + config := *db.Config + currentLogger, newLogger := config.Logger, logger.Recorder.New() + config.Logger = newLogger + tx = db.getInstance() - tx.Logger = newLogger + tx.Config = &config + if rows, err := tx.Rows(); err != nil { tx.AddError(err) } else { diff --git a/migrator/migrator.go b/migrator/migrator.go index c564cb67..9493a00c 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -32,6 +32,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error stmt := &gorm.Statement{DB: m.DB} if m.DB.Statement != nil { stmt.Table = m.DB.Statement.Table + stmt.TableExpr = m.DB.Statement.TableExpr } if table, ok := value.(string); ok { @@ -161,6 +162,10 @@ func (m Migrator) CreateTable(values ...interface{}) error { hasPrimaryKeyInDataType bool ) + if stmt.TableExpr != nil { + values[0] = *stmt.TableExpr + } + for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.FieldsByDBName[dbName] createTableSQL += "? ?" @@ -370,9 +375,9 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy alterColumn = true } else { // has size in data type and not equal - matches := regexp.MustCompile(`[^\d](\d+)[^\d]`).FindAllStringSubmatch(realDataType, -1) - matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]`).FindAllStringSubmatch(fullDataType, -1) - if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size)) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) { + matches := regexp.MustCompile(`[^\d](\d+)[^\d]?`).FindAllStringSubmatch(realDataType, -1) + matches2 := regexp.MustCompile(`[^\d]*(\d+)[^\d]?`).FindAllStringSubmatch(fullDataType, -1) + if (len(matches) == 1 && matches[0][1] != fmt.Sprint(field.Size) || !field.PrimaryKey) && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length)) { alterColumn = true } } diff --git a/schema/field.go b/schema/field.go index e7f5b708..b303fb30 100644 --- a/schema/field.go +++ b/schema/field.go @@ -762,13 +762,15 @@ func (field *Field) setupValuerAndSetter() { // pointer scanner field.Set = func(value reflect.Value, v interface{}) (err error) { reflectV := reflect.ValueOf(v) - if reflectV.Type().AssignableTo(field.FieldType) { + if !reflectV.IsValid() { + field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { - if reflectV.IsNil() { + if reflectV.IsNil() || !reflectV.IsValid() { field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) } else { - err = field.Set(value, reflectV.Elem().Interface()) + return field.Set(value, reflectV.Elem().Interface()) } } else { fieldValue := field.ReflectValueOf(value) diff --git a/tests/count_test.go b/tests/count_test.go index 0d348227..41bad71d 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -70,6 +70,11 @@ func TestCount(t *testing.T) { var count4 int64 if err := DB.Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 { + t.Errorf("count with join, got error: %v, count %v", err, count4) + } + + var count5 int64 + if err := DB.Table("users").Where("users.name = ?", user1.Name).Order("name").Count(&count5).Error; err != nil || count5 != 1 { t.Errorf("count with join, got error: %v, count %v", err, count) } } diff --git a/tests/go.mod b/tests/go.mod index ddb1773b..3fa011f1 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,11 +6,11 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v1.0.2 - gorm.io/driver/postgres v1.0.4 + gorm.io/driver/mysql v1.0.3 + gorm.io/driver/postgres v1.0.5 gorm.io/driver/sqlite v1.1.3 gorm.io/driver/sqlserver v1.0.5 - gorm.io/gorm v1.20.2 + gorm.io/gorm v1.20.4 ) replace gorm.io/gorm => ../ diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 4cc8a7c3..275fe634 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -48,11 +48,13 @@ func TestMigrate(t *testing.T) { } func TestSmartMigrateColumn(t *testing.T) { + fullSupported := map[string]bool{"mysql": true, "postgres": true}[DB.Dialector.Name()] + type UserMigrateColumn struct { ID uint Name string Salary float64 - Birthday time.Time + Birthday time.Time `gorm:"precision:4"` } DB.Migrator().DropTable(&UserMigrateColumn{}) @@ -78,15 +80,15 @@ func TestSmartMigrateColumn(t *testing.T) { for _, columnType := range columnTypes { switch columnType.Name() { case "name": - if length, _ := columnType.Length(); length != 0 && length != 128 { + if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 128 { t.Fatalf("name's length should be 128, but got %v", length) } case "salary": - if precision, o, _ := columnType.DecimalSize(); precision != 0 && precision != 2 { + if precision, o, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 { t.Fatalf("salary's precision should be 2, but got %v %v", precision, o) } case "birthday": - if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 2 { + if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 { t.Fatalf("birthday's precision should be 2, but got %v", precision) } } @@ -111,15 +113,15 @@ func TestSmartMigrateColumn(t *testing.T) { for _, columnType := range columnTypes { switch columnType.Name() { case "name": - if length, _ := columnType.Length(); length != 0 && length != 256 { + if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 256 { t.Fatalf("name's length should be 128, but got %v", length) } case "salary": - if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 { + if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 { t.Fatalf("salary's precision should be 2, but got %v", precision) } case "birthday": - if precision, _, _ := columnType.DecimalSize(); precision != 0 && precision != 3 { + if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 { t.Fatalf("birthday's precision should be 2, but got %v", precision) } }