Fix Pluck with Time and Scanner
This commit is contained in:
		
							parent
							
								
									c0de3c5051
								
							
						
					
					
						commit
						ba253982bf
					
				
							
								
								
									
										13
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								scan.go
									
									
									
									
									
								
							| @ -5,6 +5,7 @@ import ( | |||||||
| 	"database/sql/driver" | 	"database/sql/driver" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"gorm.io/gorm/schema" | 	"gorm.io/gorm/schema" | ||||||
| ) | ) | ||||||
| @ -82,7 +83,7 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { | |||||||
| 			scanIntoMap(mapValue, values, columns) | 			scanIntoMap(mapValue, values, columns) | ||||||
| 			*dest = append(*dest, mapValue) | 			*dest = append(*dest, mapValue) | ||||||
| 		} | 		} | ||||||
| 	case *int, *int64, *uint, *uint64, *float32, *float64, *string: | 	case *int, *int64, *uint, *uint64, *float32, *float64, *string, *time.Time: | ||||||
| 		for initialized || rows.Next() { | 		for initialized || rows.Next() { | ||||||
| 			initialized = false | 			initialized = false | ||||||
| 			db.RowsAffected++ | 			db.RowsAffected++ | ||||||
| @ -134,7 +135,15 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { | |||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			// pluck values into slice of data
 | 			// pluck values into slice of data
 | ||||||
| 			isPluck := len(fields) == 1 && reflectValueType.Kind() != reflect.Struct | 			isPluck := false | ||||||
|  | 			if len(fields) == 1 { | ||||||
|  | 				if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); ok { | ||||||
|  | 					isPluck = true | ||||||
|  | 				} else if reflectValueType.Kind() != reflect.Struct || reflectValueType.ConvertibleTo(schema.TimeReflectType) { | ||||||
|  | 					isPluck = true | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
| 			for initialized || rows.Next() { | 			for initialized || rows.Next() { | ||||||
| 				initialized = false | 				initialized = false | ||||||
| 				db.RowsAffected++ | 				db.RowsAffected++ | ||||||
|  | |||||||
| @ -18,6 +18,8 @@ type DataType string | |||||||
| 
 | 
 | ||||||
| type TimeType int64 | type TimeType int64 | ||||||
| 
 | 
 | ||||||
|  | var TimeReflectType = reflect.TypeOf(time.Time{}) | ||||||
|  | 
 | ||||||
| const ( | const ( | ||||||
| 	UnixSecond      TimeType = 1 | 	UnixSecond      TimeType = 1 | ||||||
| 	UnixMillisecond TimeType = 2 | 	UnixMillisecond TimeType = 2 | ||||||
| @ -102,7 +104,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | |||||||
| 			var getRealFieldValue func(reflect.Value) | 			var getRealFieldValue func(reflect.Value) | ||||||
| 			getRealFieldValue = func(v reflect.Value) { | 			getRealFieldValue = func(v reflect.Value) { | ||||||
| 				rv := reflect.Indirect(v) | 				rv := reflect.Indirect(v) | ||||||
| 				if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { | 				if rv.Kind() == reflect.Struct && !rv.Type().ConvertibleTo(TimeReflectType) { | ||||||
| 					for i := 0; i < rv.Type().NumField(); i++ { | 					for i := 0; i < rv.Type().NumField(); i++ { | ||||||
| 						newFieldType := rv.Type().Field(i).Type | 						newFieldType := rv.Type().Field(i).Type | ||||||
| 						for newFieldType.Kind() == reflect.Ptr { | 						for newFieldType.Kind() == reflect.Ptr { | ||||||
| @ -221,7 +223,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { | |||||||
| 	case reflect.Struct: | 	case reflect.Struct: | ||||||
| 		if _, ok := fieldValue.Interface().(*time.Time); ok { | 		if _, ok := fieldValue.Interface().(*time.Time); ok { | ||||||
| 			field.DataType = Time | 			field.DataType = Time | ||||||
| 		} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(time.Time{})) { | 		} else if fieldValue.Type().ConvertibleTo(TimeReflectType) { | ||||||
| 			field.DataType = Time | 			field.DataType = Time | ||||||
| 		} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { | 		} else if fieldValue.Type().ConvertibleTo(reflect.TypeOf(&time.Time{})) { | ||||||
| 			field.DataType = Time | 			field.DataType = Time | ||||||
|  | |||||||
| @ -1,6 +1,7 @@ | |||||||
| package tests_test | package tests_test | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"database/sql" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"regexp" | 	"regexp" | ||||||
| @ -431,6 +432,33 @@ func TestPluck(t *testing.T) { | |||||||
| 			t.Errorf("Unexpected result on pluck id, got %+v", ids) | 			t.Errorf("Unexpected result on pluck id, got %+v", ids) | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	var times []time.Time | ||||||
|  | 	if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", ×).Error; err != nil { | ||||||
|  | 		t.Errorf("got error when pluck time: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for idx, tv := range times { | ||||||
|  | 		AssertEqual(t, tv, users[idx].CreatedAt) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var ptrtimes []*time.Time | ||||||
|  | 	if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &ptrtimes).Error; err != nil { | ||||||
|  | 		t.Errorf("got error when pluck time: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for idx, tv := range ptrtimes { | ||||||
|  | 		AssertEqual(t, tv, users[idx].CreatedAt) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var nulltimes []sql.NullTime | ||||||
|  | 	if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Pluck("created_at", &nulltimes).Error; err != nil { | ||||||
|  | 		t.Errorf("got error when pluck time: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for idx, tv := range nulltimes { | ||||||
|  | 		AssertEqual(t, tv.Time, users[idx].CreatedAt) | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestSelect(t *testing.T) { | func TestSelect(t *testing.T) { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu