Refactor setter, valuer
This commit is contained in:
		
							parent
							
								
									69b851ed95
								
							
						
					
					
						commit
						fb52b97363
					
				
							
								
								
									
										14
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								scan.go
									
									
									
									
									
								
							@ -10,6 +10,7 @@ import (
 | 
			
		||||
	"gorm.io/gorm/schema"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// prepareValues prepare values slice
 | 
			
		||||
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
 | 
			
		||||
	if db.Statement.Schema != nil {
 | 
			
		||||
		for idx, name := range columns {
 | 
			
		||||
@ -99,14 +100,17 @@ func (db *DB) scanIntoStruct(sch *schema.Schema, rows *sql.Rows, reflectValue re
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ScanMode scan data mode
 | 
			
		||||
type ScanMode uint8
 | 
			
		||||
 | 
			
		||||
// scan modes
 | 
			
		||||
const (
 | 
			
		||||
	ScanInitialized         ScanMode = 1 << 0 // 1
 | 
			
		||||
	ScanUpdate              ScanMode = 1 << 1 // 2
 | 
			
		||||
	ScanOnConflictDoNothing ScanMode = 1 << 2 // 4
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Scan scan rows into db statement
 | 
			
		||||
func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
 | 
			
		||||
	var (
 | 
			
		||||
		columns, _          = rows.Columns()
 | 
			
		||||
@ -138,7 +142,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
 | 
			
		||||
			}
 | 
			
		||||
			scanIntoMap(mapValue, values, columns)
 | 
			
		||||
		}
 | 
			
		||||
	case *[]map[string]interface{}, []map[string]interface{}:
 | 
			
		||||
	case *[]map[string]interface{}:
 | 
			
		||||
		columnTypes, _ := rows.ColumnTypes()
 | 
			
		||||
		for initialized || rows.Next() {
 | 
			
		||||
			prepareValues(values, db, columnTypes, columns)
 | 
			
		||||
@ -149,11 +153,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
 | 
			
		||||
 | 
			
		||||
			mapValue := map[string]interface{}{}
 | 
			
		||||
			scanIntoMap(mapValue, values, columns)
 | 
			
		||||
			if values, ok := dest.([]map[string]interface{}); ok {
 | 
			
		||||
				values = append(values, mapValue)
 | 
			
		||||
			} else if values, ok := dest.(*[]map[string]interface{}); ok {
 | 
			
		||||
				*values = append(*values, mapValue)
 | 
			
		||||
			}
 | 
			
		||||
			*dest = append(*dest, mapValue)
 | 
			
		||||
		}
 | 
			
		||||
	case *int, *int8, *int16, *int32, *int64,
 | 
			
		||||
		*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
 | 
			
		||||
@ -174,7 +174,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
 | 
			
		||||
			reflectValue = db.Statement.ReflectValue
 | 
			
		||||
		)
 | 
			
		||||
 | 
			
		||||
		if reflectValue.Kind() == reflect.Interface {
 | 
			
		||||
		for reflectValue.Kind() == reflect.Interface {
 | 
			
		||||
			reflectValue = reflectValue.Elem()
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										119
									
								
								schema/field.go
									
									
									
									
									
								
							
							
						
						
									
										119
									
								
								schema/field.go
									
									
									
									
									
								
							@ -167,6 +167,12 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if _, isSerializer := fieldValue.Interface().(SerializerInterface); isSerializer {
 | 
			
		||||
		field.DataType = String
 | 
			
		||||
	} else if _, ok := field.TagSettings["SERIALIZER"]; ok {
 | 
			
		||||
		field.DataType = String
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok {
 | 
			
		||||
		field.AutoIncrementIncrement, _ = strconv.ParseInt(num, 10, 64)
 | 
			
		||||
	}
 | 
			
		||||
@ -406,95 +412,52 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
 | 
			
		||||
	return field
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type GormFieldValuer interface {
 | 
			
		||||
	GormFieldValue(context.Context, *Field) (interface{}, bool)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// create valuer, setter when parse struct
 | 
			
		||||
func (field *Field) setupValuerAndSetter() {
 | 
			
		||||
	// ValueOf
 | 
			
		||||
	switch {
 | 
			
		||||
	case len(field.StructField.Index) == 1:
 | 
			
		||||
		field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
 | 
			
		||||
			fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0])
 | 
			
		||||
			fv, zero := fieldValue.Interface(), fieldValue.IsZero()
 | 
			
		||||
			if vr, ok := fv.(GormFieldValuer); ok {
 | 
			
		||||
				fv, zero = vr.GormFieldValue(ctx, field)
 | 
			
		||||
			}
 | 
			
		||||
			return fv, zero
 | 
			
		||||
		}
 | 
			
		||||
	case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0:
 | 
			
		||||
		field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
 | 
			
		||||
			fieldValue := reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
 | 
			
		||||
			fv, zero := fieldValue.Interface(), fieldValue.IsZero()
 | 
			
		||||
			if vr, ok := fv.(GormFieldValuer); ok {
 | 
			
		||||
				fv, zero = vr.GormFieldValue(ctx, field)
 | 
			
		||||
			}
 | 
			
		||||
			return fv, zero
 | 
			
		||||
		}
 | 
			
		||||
	default:
 | 
			
		||||
		field.ValueOf = func(ctx context.Context, value reflect.Value) (interface{}, bool) {
 | 
			
		||||
			v := reflect.Indirect(value)
 | 
			
		||||
	// ValueOf returns field's value and if it is zero
 | 
			
		||||
 | 
			
		||||
			for _, idx := range field.StructField.Index {
 | 
			
		||||
				if idx >= 0 {
 | 
			
		||||
					v = v.Field(idx)
 | 
			
		||||
	// if vr, ok := fv.(GormFieldValuer); ok {
 | 
			
		||||
	// 	fv, zero = vr.GormFieldValue(ctx, field)
 | 
			
		||||
	// }
 | 
			
		||||
	field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
 | 
			
		||||
		v = reflect.Indirect(v)
 | 
			
		||||
		for _, fieldIdx := range field.StructField.Index {
 | 
			
		||||
			if fieldIdx >= 0 {
 | 
			
		||||
				v = v.Field(fieldIdx)
 | 
			
		||||
			} else {
 | 
			
		||||
				v = v.Field(-fieldIdx - 1)
 | 
			
		||||
 | 
			
		||||
				if !v.IsNil() {
 | 
			
		||||
					v = v.Elem()
 | 
			
		||||
				} else {
 | 
			
		||||
					v = v.Field(-idx - 1)
 | 
			
		||||
 | 
			
		||||
					if v.Type().Elem().Kind() != reflect.Struct {
 | 
			
		||||
						return nil, true
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					if !v.IsNil() {
 | 
			
		||||
						v = v.Elem()
 | 
			
		||||
					} else {
 | 
			
		||||
						return nil, true
 | 
			
		||||
					}
 | 
			
		||||
					return nil, true
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			fv, zero := v.Interface(), v.IsZero()
 | 
			
		||||
			if vr, ok := fv.(GormFieldValuer); ok {
 | 
			
		||||
				fv, zero = vr.GormFieldValue(ctx, field)
 | 
			
		||||
			}
 | 
			
		||||
			return fv, zero
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		fv, zero := v.Interface(), v.IsZero()
 | 
			
		||||
		return fv, zero
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// ReflectValueOf
 | 
			
		||||
	switch {
 | 
			
		||||
	case len(field.StructField.Index) == 1:
 | 
			
		||||
		field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
 | 
			
		||||
			return reflect.Indirect(value).Field(field.StructField.Index[0])
 | 
			
		||||
		}
 | 
			
		||||
	case len(field.StructField.Index) == 2 && field.StructField.Index[0] >= 0 && field.FieldType.Kind() != reflect.Ptr:
 | 
			
		||||
		field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
 | 
			
		||||
			return reflect.Indirect(value).Field(field.StructField.Index[0]).Field(field.StructField.Index[1])
 | 
			
		||||
		}
 | 
			
		||||
	default:
 | 
			
		||||
		field.ReflectValueOf = func(ctx context.Context, value reflect.Value) reflect.Value {
 | 
			
		||||
			v := reflect.Indirect(value)
 | 
			
		||||
			for idx, fieldIdx := range field.StructField.Index {
 | 
			
		||||
				if fieldIdx >= 0 {
 | 
			
		||||
					v = v.Field(fieldIdx)
 | 
			
		||||
				} else {
 | 
			
		||||
					v = v.Field(-fieldIdx - 1)
 | 
			
		||||
	// ReflectValueOf returns field's reflect value
 | 
			
		||||
	field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
 | 
			
		||||
		v = reflect.Indirect(v)
 | 
			
		||||
		for idx, fieldIdx := range field.StructField.Index {
 | 
			
		||||
			if fieldIdx >= 0 {
 | 
			
		||||
				v = v.Field(fieldIdx)
 | 
			
		||||
			} else {
 | 
			
		||||
				v = v.Field(-fieldIdx - 1)
 | 
			
		||||
 | 
			
		||||
				if v.IsNil() {
 | 
			
		||||
					v.Set(reflect.New(v.Type().Elem()))
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if v.Kind() == reflect.Ptr {
 | 
			
		||||
					if v.Type().Elem().Kind() == reflect.Struct {
 | 
			
		||||
						if v.IsNil() {
 | 
			
		||||
							v.Set(reflect.New(v.Type().Elem()))
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					if idx < len(field.StructField.Index)-1 {
 | 
			
		||||
						v = v.Elem()
 | 
			
		||||
					}
 | 
			
		||||
				if idx < len(field.StructField.Index)-1 {
 | 
			
		||||
					v = v.Elem()
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			return v
 | 
			
		||||
		}
 | 
			
		||||
		return v
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) {
 | 
			
		||||
@ -565,11 +528,7 @@ func (field *Field) setupValuerAndSetter() {
 | 
			
		||||
					field.ReflectValueOf(ctx, value).SetBool(false)
 | 
			
		||||
				}
 | 
			
		||||
			case int64:
 | 
			
		||||
				if data > 0 {
 | 
			
		||||
					field.ReflectValueOf(ctx, value).SetBool(true)
 | 
			
		||||
				} else {
 | 
			
		||||
					field.ReflectValueOf(ctx, value).SetBool(false)
 | 
			
		||||
				}
 | 
			
		||||
				field.ReflectValueOf(ctx, value).SetBool(data > 0)
 | 
			
		||||
			case string:
 | 
			
		||||
				b, _ := strconv.ParseBool(data)
 | 
			
		||||
				field.ReflectValueOf(ctx, value).SetBool(b)
 | 
			
		||||
 | 
			
		||||
@ -123,7 +123,7 @@ func TestCreateFromMap(t *testing.T) {
 | 
			
		||||
		{"name": "create_from_map_3", "Age": 20},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := DB.Model(&User{}).Create(datas).Error; err != nil {
 | 
			
		||||
	if err := DB.Model(&User{}).Create(&datas).Error; err != nil {
 | 
			
		||||
		t.Fatalf("failed to create data from slice of map, got error: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user