Fix can't scan null value into normal data types
This commit is contained in:
		
							parent
							
								
									2218e32999
								
							
						
					
					
						commit
						94685d1024
					
				| @ -168,7 +168,7 @@ func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { | ||||
| 
 | ||||
| 		return tx.Create(dest) | ||||
| 	} else if len(tx.Statement.assigns) > 0 { | ||||
| 		exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]) | ||||
| 		exprs := tx.Statement.BuildCondtion(tx.Statement.assigns[0], tx.Statement.assigns[1:]...) | ||||
| 		assigns := map[string]interface{}{} | ||||
| 		for _, expr := range exprs { | ||||
| 			if eq, ok := expr.(clause.Eq); ok { | ||||
|  | ||||
							
								
								
									
										158
									
								
								scan.go
									
									
									
									
									
								
							
							
						
						
									
										158
									
								
								scan.go
									
									
									
									
									
								
							| @ -14,40 +14,53 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { | ||||
| 
 | ||||
| 	switch dest := db.Statement.Dest.(type) { | ||||
| 	case map[string]interface{}, *map[string]interface{}: | ||||
| 		for idx, _ := range columns { | ||||
| 			values[idx] = new(interface{}) | ||||
| 		} | ||||
| 
 | ||||
| 		if initialized || rows.Next() { | ||||
| 			for idx := range columns { | ||||
| 				values[idx] = new(interface{}) | ||||
| 			} | ||||
| 
 | ||||
| 			db.RowsAffected++ | ||||
| 			db.AddError(rows.Scan(values...)) | ||||
| 		} | ||||
| 
 | ||||
| 		mapValue, ok := dest.(map[string]interface{}) | ||||
| 		if ok { | ||||
| 			if v, ok := dest.(*map[string]interface{}); ok { | ||||
| 				mapValue = *v | ||||
| 			mapValue, ok := dest.(map[string]interface{}) | ||||
| 			if !ok { | ||||
| 				if v, ok := dest.(*map[string]interface{}); ok { | ||||
| 					mapValue = *v | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			for idx, column := range columns { | ||||
| 				if v, ok := values[idx].(*interface{}); ok { | ||||
| 					if v == nil { | ||||
| 						mapValue[column] = nil | ||||
| 					} else { | ||||
| 						mapValue[column] = *v | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 
 | ||||
| 		for idx, column := range columns { | ||||
| 			mapValue[column] = *(values[idx].(*interface{})) | ||||
| 		} | ||||
| 	case *[]map[string]interface{}: | ||||
| 		for idx, _ := range columns { | ||||
| 			values[idx] = new(interface{}) | ||||
| 		} | ||||
| 
 | ||||
| 		for initialized || rows.Next() { | ||||
| 			for idx := range columns { | ||||
| 				values[idx] = new(interface{}) | ||||
| 			} | ||||
| 
 | ||||
| 			initialized = false | ||||
| 			db.RowsAffected++ | ||||
| 			db.AddError(rows.Scan(values...)) | ||||
| 
 | ||||
| 			v := map[string]interface{}{} | ||||
| 			mapValue := map[string]interface{}{} | ||||
| 			for idx, column := range columns { | ||||
| 				v[column] = *(values[idx].(*interface{})) | ||||
| 				if v, ok := values[idx].(*interface{}); ok { | ||||
| 					if v == nil { | ||||
| 						mapValue[column] = nil | ||||
| 					} else { | ||||
| 						mapValue[column] = *v | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 			*dest = append(*dest, v) | ||||
| 
 | ||||
| 			*dest = append(*dest, mapValue) | ||||
| 		} | ||||
| 	case *int, *int64, *uint, *uint64: | ||||
| 		for initialized || rows.Next() { | ||||
| @ -85,28 +98,52 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { | ||||
| 			} | ||||
| 
 | ||||
| 			for initialized || rows.Next() { | ||||
| 				for idx := range columns { | ||||
| 					values[idx] = new(interface{}) | ||||
| 				} | ||||
| 
 | ||||
| 				initialized = false | ||||
| 				db.RowsAffected++ | ||||
| 
 | ||||
| 				elem := reflect.New(reflectValueType).Elem() | ||||
| 
 | ||||
| 				if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 { | ||||
| 					// pluck
 | ||||
| 					values[0] = elem.Addr().Interface() | ||||
| 					db.AddError(rows.Scan(values...)) | ||||
| 				} else { | ||||
| 					for idx, field := range fields { | ||||
| 						if field != nil { | ||||
| 							values[idx] = field.ReflectValueOf(elem).Addr().Interface() | ||||
| 						} else if joinFields[idx][0] != nil { | ||||
| 							relValue := joinFields[idx][0].ReflectValueOf(elem) | ||||
| 							if relValue.Kind() == reflect.Ptr && relValue.IsNil() { | ||||
| 								relValue.Set(reflect.New(relValue.Type().Elem())) | ||||
| 							} | ||||
| 					db.AddError(rows.Scan(values...)) | ||||
| 
 | ||||
| 							values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface() | ||||
| 					for idx, field := range fields { | ||||
| 						if v, ok := values[idx].(*interface{}); ok { | ||||
| 							if field != nil { | ||||
| 								if v == nil { | ||||
| 									field.Set(elem, v) | ||||
| 								} else { | ||||
| 									field.Set(elem, *v) | ||||
| 								} | ||||
| 							} else if joinFields[idx][0] != nil { | ||||
| 								relValue := joinFields[idx][0].ReflectValueOf(elem) | ||||
| 								if relValue.Kind() == reflect.Ptr && relValue.IsNil() { | ||||
| 									if v == nil { | ||||
| 										continue | ||||
| 									} | ||||
| 									relValue.Set(reflect.New(relValue.Type().Elem())) | ||||
| 								} | ||||
| 
 | ||||
| 								if v == nil { | ||||
| 									joinFields[idx][1].Set(relValue, nil) | ||||
| 								} else { | ||||
| 									joinFields[idx][1].Set(relValue, *v) | ||||
| 								} | ||||
| 							} | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				db.RowsAffected++ | ||||
| 				db.AddError(rows.Scan(values...)) | ||||
| 					for idx := range columns { | ||||
| 						values[idx] = new(interface{}) | ||||
| 					} | ||||
| 				} | ||||
| 
 | ||||
| 				if isPtr { | ||||
| 					db.Statement.ReflectValue.Set(reflect.Append(db.Statement.ReflectValue, elem.Addr())) | ||||
| @ -115,30 +152,45 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) { | ||||
| 				} | ||||
| 			} | ||||
| 		case reflect.Struct: | ||||
| 			for idx, column := range columns { | ||||
| 				if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { | ||||
| 					values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() | ||||
| 				} else if names := strings.Split(column, "__"); len(names) > 1 { | ||||
| 					if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { | ||||
| 						relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) | ||||
| 						if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { | ||||
| 							if relValue.Kind() == reflect.Ptr && relValue.IsNil() { | ||||
| 								relValue.Set(reflect.New(relValue.Type().Elem())) | ||||
| 							} | ||||
| 
 | ||||
| 							values[idx] = field.ReflectValueOf(relValue).Addr().Interface() | ||||
| 							continue | ||||
| 						} | ||||
| 					} | ||||
| 					values[idx] = &sql.RawBytes{} | ||||
| 				} else { | ||||
| 					values[idx] = &sql.RawBytes{} | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if initialized || rows.Next() { | ||||
| 				for idx := range columns { | ||||
| 					values[idx] = new(interface{}) | ||||
| 				} | ||||
| 
 | ||||
| 				db.RowsAffected++ | ||||
| 				db.AddError(rows.Scan(values...)) | ||||
| 
 | ||||
| 				for idx, column := range columns { | ||||
| 					if field := db.Statement.Schema.LookUpField(column); field != nil && field.Readable { | ||||
| 						if v, ok := values[idx].(*interface{}); ok { | ||||
| 							if v == nil { | ||||
| 								field.Set(db.Statement.ReflectValue, v) | ||||
| 							} else { | ||||
| 								field.Set(db.Statement.ReflectValue, *v) | ||||
| 							} | ||||
| 						} | ||||
| 					} else if names := strings.Split(column, "__"); len(names) > 1 { | ||||
| 						if rel, ok := db.Statement.Schema.Relationships.Relations[names[0]]; ok { | ||||
| 							relValue := rel.Field.ReflectValueOf(db.Statement.ReflectValue) | ||||
| 							if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { | ||||
| 								if v, ok := values[idx].(*interface{}); ok { | ||||
| 									if relValue.Kind() == reflect.Ptr && relValue.IsNil() { | ||||
| 										if v == nil { | ||||
| 											continue | ||||
| 										} | ||||
| 										relValue.Set(reflect.New(relValue.Type().Elem())) | ||||
| 									} | ||||
| 
 | ||||
| 									if v == nil { | ||||
| 										field.Set(relValue, nil) | ||||
| 									} else { | ||||
| 										field.Set(relValue, *v) | ||||
| 									} | ||||
| 								} | ||||
| 							} | ||||
| 						} | ||||
| 					} | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
							
								
								
									
										121
									
								
								schema/field.go
									
									
									
									
									
								
							
							
						
						
									
										121
									
								
								schema/field.go
									
									
									
									
									
								
							| @ -402,34 +402,48 @@ func (field *Field) setupValuerAndSetter() { | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	recoverFunc := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { | ||||
| 	fallbackSetter := func(value reflect.Value, v interface{}, setter func(reflect.Value, interface{}) error) (err error) { | ||||
| 		if v == nil { | ||||
| 			field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) | ||||
| 		} else { | ||||
| 			reflectV := reflect.ValueOf(v) | ||||
| 
 | ||||
| 			if reflectV.Type().ConvertibleTo(field.FieldType) { | ||||
| 			if reflectV.Type().AssignableTo(field.FieldType) { | ||||
| 				field.ReflectValueOf(value).Set(reflectV) | ||||
| 				return | ||||
| 			} else if reflectV.Type().ConvertibleTo(field.FieldType) { | ||||
| 				field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) | ||||
| 			} else if valuer, ok := v.(driver.Valuer); ok { | ||||
| 				if v, err = valuer.Value(); err == nil { | ||||
| 					return setter(value, v) | ||||
| 				} | ||||
| 			} else if field.FieldType.Kind() == reflect.Ptr && reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { | ||||
| 				return | ||||
| 			} else if field.FieldType.Kind() == reflect.Ptr { | ||||
| 				fieldValue := field.ReflectValueOf(value) | ||||
| 				if fieldValue.IsNil() { | ||||
| 					if v == nil { | ||||
| 						return nil | ||||
| 
 | ||||
| 				if reflectV.Type().AssignableTo(field.FieldType.Elem()) { | ||||
| 					if fieldValue.IsNil() { | ||||
| 						fieldValue.Set(reflect.New(field.FieldType.Elem())) | ||||
| 					} | ||||
| 					fieldValue.Set(reflect.New(field.FieldType.Elem())) | ||||
| 					fieldValue.Elem().Set(reflectV) | ||||
| 					return | ||||
| 				} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { | ||||
| 					if fieldValue.IsNil() { | ||||
| 						fieldValue.Set(reflect.New(field.FieldType.Elem())) | ||||
| 					} | ||||
| 
 | ||||
| 					fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem())) | ||||
| 					return | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			if valuer, ok := v.(driver.Valuer); ok { | ||||
| 				if v, err = valuer.Value(); err == nil { | ||||
| 					setter(value, v) | ||||
| 				} | ||||
| 				fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem())) | ||||
| 			} else if reflectV.Kind() == reflect.Ptr { | ||||
| 				return field.Set(value, reflectV.Elem().Interface()) | ||||
| 				setter(value, reflectV.Elem().Interface()) | ||||
| 			} else { | ||||
| 				return fmt.Errorf("failed to set value %+v to field %v", v, field.Name) | ||||
| 			} | ||||
| 		} | ||||
| 		return err | ||||
| 
 | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// Set
 | ||||
| @ -441,8 +455,17 @@ func (field *Field) setupValuerAndSetter() { | ||||
| 				field.ReflectValueOf(value).SetBool(data) | ||||
| 			case *bool: | ||||
| 				field.ReflectValueOf(value).SetBool(*data) | ||||
| 			case int64: | ||||
| 				if data > 0 { | ||||
| 					field.ReflectValueOf(value).SetBool(true) | ||||
| 				} else { | ||||
| 					field.ReflectValueOf(value).SetBool(false) | ||||
| 				} | ||||
| 			case string: | ||||
| 				b, _ := strconv.ParseBool(data) | ||||
| 				field.ReflectValueOf(value).SetBool(b) | ||||
| 			default: | ||||
| 				return recoverFunc(value, v, field.Set) | ||||
| 				return fallbackSetter(value, v, field.Set) | ||||
| 			} | ||||
| 			return nil | ||||
| 		} | ||||
| @ -498,7 +521,7 @@ func (field *Field) setupValuerAndSetter() { | ||||
| 					field.ReflectValueOf(value).SetInt(0) | ||||
| 				} | ||||
| 			default: | ||||
| 				return recoverFunc(value, v, field.Set) | ||||
| 				return fallbackSetter(value, v, field.Set) | ||||
| 			} | ||||
| 			return err | ||||
| 		} | ||||
| @ -538,7 +561,7 @@ func (field *Field) setupValuerAndSetter() { | ||||
| 					return err | ||||
| 				} | ||||
| 			default: | ||||
| 				return recoverFunc(value, v, field.Set) | ||||
| 				return fallbackSetter(value, v, field.Set) | ||||
| 			} | ||||
| 			return err | ||||
| 		} | ||||
| @ -578,7 +601,7 @@ func (field *Field) setupValuerAndSetter() { | ||||
| 					return err | ||||
| 				} | ||||
| 			default: | ||||
| 				return recoverFunc(value, v, field.Set) | ||||
| 				return fallbackSetter(value, v, field.Set) | ||||
| 			} | ||||
| 			return err | ||||
| 		} | ||||
| @ -594,7 +617,7 @@ func (field *Field) setupValuerAndSetter() { | ||||
| 			case float64, float32: | ||||
| 				field.ReflectValueOf(value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data)) | ||||
| 			default: | ||||
| 				return recoverFunc(value, v, field.Set) | ||||
| 				return fallbackSetter(value, v, field.Set) | ||||
| 			} | ||||
| 			return err | ||||
| 		} | ||||
| @ -615,7 +638,7 @@ func (field *Field) setupValuerAndSetter() { | ||||
| 						return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) | ||||
| 					} | ||||
| 				default: | ||||
| 					return recoverFunc(value, v, field.Set) | ||||
| 					return fallbackSetter(value, v, field.Set) | ||||
| 				} | ||||
| 				return nil | ||||
| 			} | ||||
| @ -625,9 +648,6 @@ func (field *Field) setupValuerAndSetter() { | ||||
| 				case time.Time: | ||||
| 					fieldValue := field.ReflectValueOf(value) | ||||
| 					if fieldValue.IsNil() { | ||||
| 						if v == nil { | ||||
| 							return nil | ||||
| 						} | ||||
| 						fieldValue.Set(reflect.New(field.FieldType.Elem())) | ||||
| 					} | ||||
| 					fieldValue.Elem().Set(reflect.ValueOf(v)) | ||||
| @ -647,7 +667,7 @@ func (field *Field) setupValuerAndSetter() { | ||||
| 						return fmt.Errorf("failed to set string %v to time.Time field %v, failed to parse it as time, got error %v", v, field.Name, err) | ||||
| 					} | ||||
| 				default: | ||||
| 					return recoverFunc(value, v, field.Set) | ||||
| 					return fallbackSetter(value, v, field.Set) | ||||
| 				} | ||||
| 				return nil | ||||
| 			} | ||||
| @ -655,53 +675,42 @@ func (field *Field) setupValuerAndSetter() { | ||||
| 			if _, ok := fieldValue.Interface().(sql.Scanner); ok { | ||||
| 				// struct scanner
 | ||||
| 				field.Set = func(value reflect.Value, v interface{}) (err error) { | ||||
| 					if v == nil { | ||||
| 					if valuer, ok := v.(driver.Valuer); ok { | ||||
| 						v, _ = valuer.Value() | ||||
| 					} | ||||
| 
 | ||||
| 					reflectV := reflect.ValueOf(v) | ||||
| 					if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { | ||||
| 						field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) | ||||
| 					} else { | ||||
| 						reflectV := reflect.ValueOf(v) | ||||
| 						if reflectV.Type().ConvertibleTo(field.FieldType) { | ||||
| 							field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) | ||||
| 						} else if valuer, ok := v.(driver.Valuer); ok { | ||||
| 							if v, err = valuer.Value(); err == nil { | ||||
| 								err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) | ||||
| 							} | ||||
| 						} else { | ||||
| 							err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) | ||||
| 						} | ||||
| 						err = field.ReflectValueOf(value).Addr().Interface().(sql.Scanner).Scan(v) | ||||
| 					} | ||||
| 					return | ||||
| 				} | ||||
| 			} else if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok { | ||||
| 				// pointer scanner
 | ||||
| 				field.Set = func(value reflect.Value, v interface{}) (err error) { | ||||
| 					if v == nil { | ||||
| 					if valuer, ok := v.(driver.Valuer); ok { | ||||
| 						v, _ = valuer.Value() | ||||
| 					} | ||||
| 
 | ||||
| 					reflectV := reflect.ValueOf(v) | ||||
| 					if reflectV.Type().ConvertibleTo(field.FieldType) { | ||||
| 						field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) | ||||
| 					} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { | ||||
| 						field.ReflectValueOf(value).Set(reflect.New(field.FieldType).Elem()) | ||||
| 					} else { | ||||
| 						reflectV := reflect.ValueOf(v) | ||||
| 						if reflectV.Type().ConvertibleTo(field.FieldType) { | ||||
| 							field.ReflectValueOf(value).Set(reflectV.Convert(field.FieldType)) | ||||
| 						} else if reflectV.Type().ConvertibleTo(field.FieldType.Elem()) { | ||||
| 							fieldValue := field.ReflectValueOf(value) | ||||
| 							if fieldValue.IsNil() { | ||||
| 								if v == nil { | ||||
| 									return nil | ||||
| 								} | ||||
| 								fieldValue.Set(reflect.New(field.FieldType.Elem())) | ||||
| 							} | ||||
| 							fieldValue.Elem().Set(reflectV.Convert(field.FieldType.Elem())) | ||||
| 						} else if valuer, ok := v.(driver.Valuer); ok { | ||||
| 							if v, err = valuer.Value(); err == nil { | ||||
| 								err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v) | ||||
| 							} | ||||
| 						} else { | ||||
| 							err = field.ReflectValueOf(value).Interface().(sql.Scanner).Scan(v) | ||||
| 						fieldValue := field.ReflectValueOf(value) | ||||
| 						if fieldValue.IsNil() { | ||||
| 							fieldValue.Set(reflect.New(field.FieldType.Elem())) | ||||
| 						} | ||||
| 						err = fieldValue.Interface().(sql.Scanner).Scan(v) | ||||
| 					} | ||||
| 					return | ||||
| 				} | ||||
| 			} else { | ||||
| 				field.Set = func(value reflect.Value, v interface{}) (err error) { | ||||
| 					return recoverFunc(value, v, field.Set) | ||||
| 					return fallbackSetter(value, v, field.Set) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
							
								
								
									
										12
									
								
								statement.go
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								statement.go
									
									
									
									
									
								
							| @ -146,8 +146,16 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { | ||||
| 		case clause.Column, clause.Table: | ||||
| 			stmt.QuoteTo(writer, v) | ||||
| 		case clause.Expr: | ||||
| 			writer.WriteString(v.SQL) | ||||
| 			stmt.Vars = append(stmt.Vars, v.Vars...) | ||||
| 			var varStr strings.Builder | ||||
| 			var sql = v.SQL | ||||
| 			for _, arg := range v.Vars { | ||||
| 				stmt.Vars = append(stmt.Vars, arg) | ||||
| 				stmt.DB.Dialector.BindVarTo(&varStr, stmt, arg) | ||||
| 				sql = strings.Replace(sql, "?", varStr.String(), 1) | ||||
| 				varStr.Reset() | ||||
| 			} | ||||
| 
 | ||||
| 			writer.WriteString(sql) | ||||
| 		case driver.Valuer: | ||||
| 			stmt.Vars = append(stmt.Vars, v) | ||||
| 			stmt.DB.Dialector.BindVarTo(writer, stmt, v) | ||||
|  | ||||
| @ -6,11 +6,6 @@ import ( | ||||
| 	. "gorm.io/gorm/utils/tests" | ||||
| ) | ||||
| 
 | ||||
| func TestMain(m *testing.M) { | ||||
| 	RunMigrations() | ||||
| 	m.Run() | ||||
| } | ||||
| 
 | ||||
| func TestExceptionsWithInvalidSql(t *testing.T) { | ||||
| 	var columns []string | ||||
| 	if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { | ||||
|  | ||||
| @ -1299,7 +1299,6 @@ func TestNilPointerSlice(t *testing.T) { | ||||
| 	) | ||||
| 
 | ||||
| 	DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) | ||||
| 
 | ||||
| 	if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { | ||||
| 		t.Error(err) | ||||
| 	} | ||||
|  | ||||
| @ -435,3 +435,35 @@ func TestSubQueryWithHaving(t *testing.T) { | ||||
| 		t.Errorf("Two user group should be found, instead found %d", len(results)) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestScanNullValue(t *testing.T) { | ||||
| 	user := GetUser("scan_null_value", Config{}) | ||||
| 	DB.Create(&user) | ||||
| 
 | ||||
| 	if err := DB.Model(&user).Update("age", nil).Error; err != nil { | ||||
| 		t.Fatalf("failed to update column age for struct, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var result User | ||||
| 	if err := DB.First(&result, "id = ?", user.ID).Error; err != nil { | ||||
| 		t.Fatalf("failed to query struct data with null age, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	AssertEqual(t, result, user) | ||||
| 
 | ||||
| 	users := []User{ | ||||
| 		*GetUser("scan_null_value_for_slice_1", Config{}), | ||||
| 		*GetUser("scan_null_value_for_slice_2", Config{}), | ||||
| 		*GetUser("scan_null_value_for_slice_3", Config{}), | ||||
| 	} | ||||
| 	DB.Create(&users) | ||||
| 
 | ||||
| 	if err := DB.Model(&users[0]).Update("age", nil).Error; err != nil { | ||||
| 		t.Fatalf("failed to update column age for struct, got error %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	var results []User | ||||
| 	if err := DB.Find(&results, "name like ?", "scan_null_value_for_slice%").Error; err != nil { | ||||
| 		t.Fatalf("failed to query slice data with null age, got error %v", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @ -1,4 +1,4 @@ | ||||
| dialects=("sqlite" "mysql" "postgres" "mssql") | ||||
| dialects=("sqlite" "mysql" "postgres" "sqlserver") | ||||
| 
 | ||||
| if [[ $(pwd) == *"gorm/tests"* ]]; then | ||||
|   cd .. | ||||
| @ -10,7 +10,7 @@ for dialect in "${dialects[@]}" ; do | ||||
|     echo "testing ${dialect}..." | ||||
| 
 | ||||
|     race="" | ||||
|     if [ "$GORM_VERBOSE" = "" ] | ||||
|     if [ "$GORM_DIALECT" = "sqlserver" ] | ||||
|     then | ||||
|       race="-race" | ||||
|     fi | ||||
|  | ||||
| @ -23,6 +23,8 @@ func init() { | ||||
| 	if DB, err = OpenTestConnection(); err != nil { | ||||
| 		log.Printf("failed to connect database, got error %v\n", err) | ||||
| 		os.Exit(1) | ||||
| 	} else { | ||||
| 		RunMigrations() | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -155,12 +155,14 @@ func TestUpdates(t *testing.T) { | ||||
| 	AssertEqual(t, user2.UpdatedAt, user3.UpdatedAt) | ||||
| 
 | ||||
| 	// update with gorm exprs
 | ||||
| 	DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}) | ||||
| 	if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil { | ||||
| 		t.Errorf("Not error should happen when updating with gorm expr, but got %v", err) | ||||
| 	} | ||||
| 	var user4 User | ||||
| 	DB.First(&user4, user3.ID) | ||||
| 
 | ||||
| 	user3.Age += 100 | ||||
| 	AssertEqual(t, user4.UpdatedAt, user3.UpdatedAt) | ||||
| 	AssertObjEqual(t, user4, user3, "UpdatedAt", "Age") | ||||
| } | ||||
| 
 | ||||
| func TestUpdateColumn(t *testing.T) { | ||||
|  | ||||
| @ -121,6 +121,11 @@ func TestFindOrCreate(t *testing.T) { | ||||
| 
 | ||||
| 	updatedAt1 := user4.UpdatedAt | ||||
| 	DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4) | ||||
| 
 | ||||
| 	if user4.Age != 55 { | ||||
| 		t.Errorf("Failed to set change to 55, got %v", user4.Age) | ||||
| 	} | ||||
| 
 | ||||
| 	if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) { | ||||
| 		t.Errorf("UpdateAt should be changed when update values with assign") | ||||
| 	} | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Jinzhu
						Jinzhu